@@ -58,7 +58,6 @@ model = load_model_from_id(ModelID.OLMOEARTH_V1_BASE, load_weights=True)
5858model_with_weights = load_model_from_id(ModelID.OLMOEARTH_V1_NANO , load_weights = True )
5959```
6060
61-
6261### Direct Model Initialization (Custom Configuration)
6362
6463For custom configurations (e.g., custom modalities), you can directly instantiate the model class:
@@ -92,7 +91,83 @@ weights = torch.load("path/to/weights.pth")
9291model.load_state_dict(weights)
9392```
9493
94+ ### Data Normalization
95+
96+ The model expects normalized input data. Use the ` Normalizer ` class to normalize your data before passing it to the model.
97+
98+ ** Important:** Data must be provided with bands in the specific order expected by each modality. See the band order section below.
99+
100+ ### Sample Code
101+
102+ ``` python
103+ import torch
104+ import numpy as np
105+
106+ from olmoearth_pretrain_minimal import load_model_from_id, ModelID, Normalizer
107+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality
108+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import MaskedOlmoEarthSample
109+
110+ # Initialize normalizer
111+ normalizer = Normalizer(std_multiplier = 2.0 )
112+
113+ # Prepare Sentinel-2 L2A data: (batch, height, width, time, bands)
114+ # Bands must match Modality.SENTINEL2_L2A.band_order (12 bands)
115+ sentinel2_data = np.random.rand(1 , 128 , 128 , 12 , 12 ).astype(np.float32)
116+
117+ # Normalize the data
118+ normalized_sentinel2 = normalizer.normalize(Modality.SENTINEL2_L2A , sentinel2_data)
119+
120+ model = load_model_from_id(ModelID.OLMOEARTH_V1_BASE , load_weights = True )
121+ model.eval()
122+
123+ # Create minimal sample (timestamps required, month must be long for embedding)
124+ timestamps = torch.zeros(1 , 12 , 3 , dtype = torch.long)
125+ timestamps[:, :, 1 ] = torch.arange(12 , dtype = torch.long) # months 0-11
126+
127+ sample = MaskedOlmoEarthSample(
128+ timestamps = timestamps,
129+ sentinel2_l2a = torch.from_numpy(normalized_sentinel2).float(),
130+ sentinel2_l2a_mask = torch.zeros(1 , 128 , 128 , 12 , dtype = torch.long),
131+ )
132+
133+ with torch.no_grad():
134+ output = model.encoder(sample, patch_size = 8 , input_res = 10 , fast_pass = True )
135+ ```
136+
137+ ### Expected Band Orders
138+
139+ The model expects data with bands in a specific order for each modality. Use ` Modality.<MODALITY_NAME>.band_order ` to get the correct order:
140+
141+ ``` python
142+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality
143+
144+ # Sentinel-2 L2A band order (12 bands)
145+ print (Modality.SENTINEL2_L2A .band_order)
146+ # ['B02', 'B03', 'B04', 'B08', 'B05', 'B06', 'B07', 'B8A', 'B11', 'B12', 'B01', 'B09']
147+
148+ # Sentinel-1 band order (2 bands)
149+ print (Modality.SENTINEL1 .band_order)
150+ # ['vv', 'vh']
151+
152+ # Landsat band order (11 bands)
153+ print (Modality.LANDSAT .band_order)
154+ # ['B8', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B9', 'B10', 'B11']
155+
156+ # WorldCover band order (1 band)
157+ print (Modality.WORLDCOVER .band_order)
158+ # ['B1']
159+
160+ # SRTM band order (1 band)
161+ print (Modality.SRTM .band_order)
162+ # ['srtm']
163+ ```
164+
165+ ** Key points:**
166+ - The last dimension of your data array must match the band order exactly
167+ - For multitemporal modalities (Sentinel-2, Sentinel-1, Landsat), data shape is ` (batch, height, width, time, bands) `
168+ - For single-temporal modalities (WorldCover, SRTM, etc.), data shape is ` (batch, height, width, bands) `
169+
95170### Note
96171
97- For the full package with training and evaluation capabilities, see the main ` olmoearth_pretrain ` package.
172+ For the full package with training and evaluation capabilities, see the main [ ` olmoearth_pretrain ` ] ( https://github.com/allenai/olmoearth_pretrain ) package.
98173
0 commit comments