Commit b43eee0

Add normalization (#1)
Adding normalization
1 parent 8b6bf07 commit b43eee0

8 files changed

Lines changed: 1585 additions & 2 deletions

README.md

Lines changed: 77 additions & 2 deletions
@@ -58,7 +58,6 @@ model = load_model_from_id(ModelID.OLMOEARTH_V1_BASE, load_weights=True)
model_with_weights = load_model_from_id(ModelID.OLMOEARTH_V1_NANO, load_weights=True)
```

### Direct Model Initialization (Custom Configuration)

For custom configurations (e.g., custom modalities), you can directly instantiate the model class:
@@ -92,7 +91,83 @@ weights = torch.load("path/to/weights.pth")
model.load_state_dict(weights)
```

### Data Normalization

The model expects normalized input data. Use the `Normalizer` class to normalize your data before passing it to the model.

**Important:** Data must be provided with bands in the specific order expected by each modality. See the band order section below.

### Sample Code

```python
import torch
import numpy as np

from olmoearth_pretrain_minimal import load_model_from_id, ModelID, Normalizer
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import MaskedOlmoEarthSample

# Initialize normalizer
normalizer = Normalizer(std_multiplier=2.0)

# Prepare Sentinel-2 L2A data: (batch, height, width, time, bands)
# Bands must match Modality.SENTINEL2_L2A.band_order (12 bands)
sentinel2_data = np.random.rand(1, 128, 128, 12, 12).astype(np.float32)

# Normalize the data
normalized_sentinel2 = normalizer.normalize(Modality.SENTINEL2_L2A, sentinel2_data)

model = load_model_from_id(ModelID.OLMOEARTH_V1_BASE, load_weights=True)
model.eval()

# Create minimal sample (timestamps required, month must be long for embedding)
timestamps = torch.zeros(1, 12, 3, dtype=torch.long)
timestamps[:, :, 1] = torch.arange(12, dtype=torch.long)  # months 0-11

sample = MaskedOlmoEarthSample(
    timestamps=timestamps,
    sentinel2_l2a=torch.from_numpy(normalized_sentinel2).float(),
    sentinel2_l2a_mask=torch.zeros(1, 128, 128, 12, dtype=torch.long),
)

with torch.no_grad():
    output = model.encoder(sample, patch_size=8, input_res=10, fast_pass=True)
```
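
The sample above treats `Normalizer` as a black box, and this diff does not show how it works internally. As a rough illustration only, a `std_multiplier` parameter suggests a scheme in which each band is rescaled so that the window `mean ± std_multiplier * std` maps onto `[0, 1]`. The sketch below assumes that scheme. The helper name `normalize_with_std_window` and the per-band statistics are made up for the example and are not part of the package.

```python
import numpy as np


def normalize_with_std_window(data, mean, std, std_multiplier=2.0):
    """Hypothetical helper: map [mean - k*std, mean + k*std] to [0, 1] per band."""
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    low = mean - std_multiplier * std
    high = mean + std_multiplier * std
    # Per-band statistics broadcast over the last (band) axis.
    return (data - low) / (high - low)


# Made-up statistics for a two-band example (not real sensor values).
band_mean = [1000.0, 2000.0]
band_std = [250.0, 500.0]

# Shape (batch, height, width, time, bands), as in the sample above.
raw = np.zeros((1, 4, 4, 3, 2), dtype=np.float32)
raw[..., 0] = 1000.0  # exactly at the band mean
raw[..., 1] = 3000.0  # exactly two standard deviations above the band mean

normalized = normalize_with_std_window(raw, band_mean, band_std)
```

Under this assumed scheme, a value at the band mean lands in the middle of the output range and a value at the upper edge of the window lands at the top. Values outside the window fall outside `[0, 1]` unless they are clipped separately.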

### Expected Band Orders

The model expects data with bands in a specific order for each modality. Use `Modality.<MODALITY_NAME>.band_order` to get the correct order:

```python
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality

# Sentinel-2 L2A band order (12 bands)
print(Modality.SENTINEL2_L2A.band_order)
# ['B02', 'B03', 'B04', 'B08', 'B05', 'B06', 'B07', 'B8A', 'B11', 'B12', 'B01', 'B09']

# Sentinel-1 band order (2 bands)
print(Modality.SENTINEL1.band_order)
# ['vv', 'vh']

# Landsat band order (11 bands)
print(Modality.LANDSAT.band_order)
# ['B8', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B9', 'B10', 'B11']

# WorldCover band order (1 band)
print(Modality.WORLDCOVER.band_order)
# ['B1']

# SRTM band order (1 band)
print(Modality.SRTM.band_order)
# ['srtm']
```
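
If your data stores bands in a different order, the last axis has to be permuted before normalization. The following is a minimal sketch of one way to do that with plain NumPy indexing. The `reorder_bands` helper and the `source_order` layout are hypothetical; the expected order is copied from the Sentinel-2 L2A listing above rather than read from the package.

```python
import numpy as np

# Expected order copied from the Sentinel-2 L2A listing above.
EXPECTED_S2_ORDER = [
    "B02", "B03", "B04", "B08", "B05", "B06",
    "B07", "B8A", "B11", "B12", "B01", "B09",
]


def reorder_bands(data, source_order, target_order):
    """Hypothetical helper: permute the last axis from source_order to target_order."""
    missing = [band for band in target_order if band not in source_order]
    if missing:
        raise ValueError(f"source data is missing bands: {missing}")
    index = [source_order.index(band) for band in target_order]
    return data[..., index]


# Hypothetical source layout: bands stored in band-number order.
source_order = [
    "B01", "B02", "B03", "B04", "B05", "B06",
    "B07", "B08", "B8A", "B09", "B11", "B12",
]

# Fill each band with its source position so the permutation is easy to inspect.
data = np.broadcast_to(np.arange(12, dtype=np.float32), (1, 8, 8, 3, 12))
reordered = reorder_bands(data, source_order, EXPECTED_S2_ORDER)
```

Indexing with a list on the last axis returns a copy, so the original array is left unchanged.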

**Key points:**
- The last dimension of your data array must match the band order exactly
- For multitemporal modalities (Sentinel-2, Sentinel-1, Landsat), data shape is `(batch, height, width, time, bands)`
- For single-temporal modalities (WorldCover, SRTM, etc.), data shape is `(batch, height, width, bands)`
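
These shape rules can be checked before data reaches the model. The sketch below is one possible pre-flight check, not something the package provides. The string keys, the `check_shape` helper, and the two lookup tables are hypothetical; the band counts come from the band-order listing above, and the multitemporal grouping follows the key points.

```python
import numpy as np

# Band counts taken from the band-order listing (hypothetical string keys).
BAND_COUNTS = {
    "sentinel2_l2a": 12,
    "sentinel1": 2,
    "landsat": 11,
    "worldcover": 1,
    "srtm": 1,
}
# Modalities the key points describe as multitemporal.
MULTITEMPORAL = {"sentinel2_l2a", "sentinel1", "landsat"}


def check_shape(name, data):
    """Hypothetical helper: raise ValueError if data does not fit the documented layout."""
    expected_ndim = 5 if name in MULTITEMPORAL else 4
    if data.ndim != expected_ndim:
        raise ValueError(
            f"{name}: expected {expected_ndim} dimensions, got {data.ndim}"
        )
    if data.shape[-1] != BAND_COUNTS[name]:
        raise ValueError(
            f"{name}: expected {BAND_COUNTS[name]} bands, got {data.shape[-1]}"
        )


# (batch, height, width, time, bands) for a multitemporal modality.
check_shape("sentinel2_l2a", np.zeros((1, 128, 128, 12, 12), dtype=np.float32))
# (batch, height, width, bands) for a single-temporal modality.
check_shape("srtm", np.zeros((1, 128, 128, 1), dtype=np.float32))
```

A check like this only validates dimensions and band count. It cannot detect bands that are present in the wrong order, so the band order still has to be verified against `band_order` separately.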

### Note

-For the full package with training and evaluation capabilities, see the main `olmoearth_pretrain` package.
+For the full package with training and evaluation capabilities, see the main [`olmoearth_pretrain`](https://github.com/allenai/olmoearth_pretrain) package.

olmoearth_pretrain_minimal/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,11 +6,13 @@
    load_model_from_path,
)
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1 import OlmoEarthPretrain_v1
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.data.normalize import Normalizer

__all__ = [
    "OlmoEarthPretrain_v1",
    "ModelID",
    "load_model_from_id",
    "load_model_from_path",
    "Normalizer",
]

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
"""Data utilities for OlmoEarth Pretrain v1."""

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
"""Normalization configuration files."""
