Commit ae5ac65

Merge branch 'main' into olmo-core-2.3
2 parents 34db7a7 + 1b028de commit ae5ac65

40 files changed

Lines changed: 761 additions & 335 deletions

docs/inference_quickstart.md

Lines changed: 148 additions & 0 deletions
## Inference Quickstart

This quickstart shows how to (1) initialize the OlmoEarth model, (2) obtain a satellite
image suitable for input to the model, and (3) compute embeddings from that satellite
image.

## Initializing the Model

First, set up a Python 3.12 environment. You can use your favorite Python package
manager, but here is an example with uv:

```
curl -LsSf https://astral.sh/uv/install.sh | sh
cd /path/to/olmoearth_pretrain/
uv sync
```

Now we can use the `olmoearth_pretrain` library to initialize the model in PyTorch.
Below, we initialize the OlmoEarth-v1-Base model.

```python
from olmoearth_pretrain.model_loader import ModelID, load_model

model = load_model(ModelID.OLMOEARTH_V1_BASE)
```

## Obtain Satellite Imagery

Here, we obtain one Sentinel-2 image from the ESA Copernicus Browser. If you want to
apply the model to multiple images of a location, like a time series of Sentinel-1 and
Sentinel-2 images, see the
[OlmoEarth embedding](https://github.com/allenai/rslearn/blob/master/docs/examples/OlmoEarthEmbeddings.md)
and [OlmoEarth fine-tuning](https://github.com/allenai/rslearn/blob/master/docs/examples/FinetuneOlmoEarth.md)
guides in rslearn.

To download an image from the Copernicus Browser, follow these steps:

1. Navigate to https://browser.dataspace.copernicus.eu/. Press Login to sign up for an
   account and log in.
2. Go to the Search tab at the top-left. Check Sentinel-2, then check L2A. This selects
   Sentinel-2 L2A images, which are the type of Sentinel-2 images that OlmoEarth is
   pre-trained on.
3. Modify the time range if desired. Also, use the area of interest tool at the top
   right to select a spatial area to search over.
4. Then, press Search. We recommend looking through the results to find a less cloudy
   image. You can press Visualize to preview the satellite image in the Browser before
   downloading it. Once you are satisfied, press the download icon next to the image in
   the search results. Once the download is complete, unzip the file.

If you prefer to skip using the browser, you can download and unzip a Sentinel-2 image
of Seattle:

```
wget https://storage.googleapis.com/ai2-rslearn-projects-data/artifacts/example_sentinel2_l2a_scene_of_seattle.zip
unzip example_sentinel2_l2a_scene_of_seattle.zip
```

## Compute Embeddings

Finally, we load the image in Python, normalize it, and apply the model to it to
compute embeddings.

First, we read all of the image bands at 10 m/pixel. We use the B02 band (one of the
10 m/pixel bands) to determine the transform under which to read the remaining bands,
since some are stored at lower resolutions. Note that here we assume that the .SAFE
folder is in the working directory.

```python
import glob

import numpy as np
import rasterio
from olmoearth_pretrain.data.constants import Modality
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT

# Get the JP2 filenames that we need to read, in the band order expected by OlmoEarth.
fnames = []
for band_name in Modality.SENTINEL2_L2A.band_order:
    fname = glob.glob(f"*.SAFE/GRANULE/*/IMG_DATA/*/*_{band_name}_*.jp2")[0]
    fnames.append(fname)

# Get the CRS and transform from the first band, which is B02.
with rasterio.open(fnames[0]) as src:
    crs = src.crs
    transform = src.transform
    width = src.width
    height = src.height

# Limit the width/height to a 512x512 top-left crop in case memory is limited.
width = 512
height = 512

# Now read all of the bands, warping each one onto the 10 m/pixel grid.
image = np.zeros((len(fnames), height, width), dtype=np.int32)
for band_idx, fname in enumerate(fnames):
    with rasterio.open(fname) as src:
        with WarpedVRT(
            src,
            crs=crs,
            transform=transform,
            width=width,
            height=height,
            resampling=Resampling.bilinear,
        ) as vrt:
            image[band_idx, :, :] = vrt.read(1)

# Rearrange from CHW to BHWTC (batch, height, width, time, channel).
image = image.transpose(1, 2, 0)[None, :, :, None, :]
```
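The CHW-to-BHWTC rearrangement at the end of the block above is easy to sanity-check with a dummy array (a minimal sketch; the band count and pixel dimensions here are just illustrative):

```python
import numpy as np

# Dummy "image": 12 bands, each 512x512 pixels, in CHW order.
chw = np.zeros((12, 512, 512), dtype=np.int32)

# Move channels last (HWC), then add singleton batch and time axes to get BHWTC.
bhwtc = chw.transpose(1, 2, 0)[None, :, :, None, :]

print(bhwtc.shape)  # (1, 512, 512, 1, 12)
```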

Next, we normalize the image:

```python
from olmoearth_pretrain.data.normalize import Normalizer, Strategy

normalizer = Normalizer(Strategy.COMPUTED)
image = normalizer.normalize(Modality.SENTINEL2_L2A, image)
```
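The `Normalizer` handles the per-band statistics for you. As a rough illustration only, one common form of computed normalization is per-band standardization; whether `Strategy.COMPUTED` does exactly this is an assumption, and the statistics below are made up:

```python
import numpy as np

# Hypothetical per-band mean/std; real statistics ship with the library, not here.
means = np.array([1500.0, 1200.0, 1000.0])
stds = np.array([800.0, 700.0, 600.0])

# Dummy BHWTC image with 3 bands, all pixels set to 2000.
image = np.full((1, 4, 4, 1, 3), 2000.0)

# Standardize each band: subtract the band mean, divide by the band std.
normalized = (image - means) / stds

print(normalized[0, 0, 0, 0, 0])  # (2000 - 1500) / 800 = 0.625
```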

Now we can apply the model to the image. We recommend applying it to inputs between
1x1 and 128x128 in size. The patch size can be set between 1 and 8; smaller patch sizes
generally perform better, but require more GPU time.

```python
import torch

from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue

device = torch.device("cuda")
model.to(device)

# Run the model on the top-left 64x64 of the image.
sample = MaskedOlmoEarthSample(
    sentinel2_l2a=torch.tensor(image[:, 0:64, 0:64, :, :], dtype=torch.float32, device=device),
    # The mask shape is BHWTS, where S is the number of band sets (3 for Sentinel-2).
    sentinel2_l2a_mask=torch.ones((1, 64, 64, 1, 3), dtype=torch.float32, device=device) * MaskValue.ONLINE_ENCODER.value,
    # The timestamp format is (day of month 1-31, month 0-11, year).
    # The values here correspond to the date of our sample Sentinel-2 image of Seattle
    # (2025-08-22), so August is encoded as month 7.
    timestamps=torch.tensor([22, 7, 2025], device=device)[None, None, :],
)
tokens_and_masks = model.encoder(
    sample, fast_pass=True, patch_size=4,
)["tokens_and_masks"]
# Get the Sentinel-2 features.
modality_features = tokens_and_masks.sentinel2_l2a
# Pool the features over the timestep and band set dimensions so we end up with a BHWC
# feature map.
pooled = modality_features.mean(dim=[3, 4])
```
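To build intuition for the patch-size trade-off and the timestamp encoding, here is a small self-contained sketch. The token-count formula assumes the usual ViT-style `(H / patch) x (W / patch)` spatial grid (an assumption about this encoder, not confirmed by the source); the timestamp layout follows the comment in the block above:

```python
from datetime import date


def token_grid(height: int, width: int, patch_size: int) -> int:
    """Number of spatial tokens for a ViT-style patch embedding (assumed)."""
    return (height // patch_size) * (width // patch_size)


# A 64x64 input at patch size 4 yields a 16x16 grid, i.e. 256 spatial tokens.
print(token_grid(64, 64, 4))  # 256
# Halving the patch size quadruples the token count (and hence the compute).
print(token_grid(64, 64, 2))  # 1024


def encode_timestamp(d: date) -> list[int]:
    """Encode a date as (day of month 1-31, month 0-11, year)."""
    return [d.day, d.month - 1, d.year]


# The Seattle sample image date, 2025-08-22, encodes as below.
print(encode_timestamp(date(2025, 8, 22)))  # [22, 7, 2025]
```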

olmoearth_pretrain/convert_dataset_to_studio_format.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

olmoearth_pretrain/data/dataset.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -903,7 +903,7 @@ def read_h5_file(
         """Read the h5 file."""
         if self.cache_dir is not None:
             cache_file_path = self.cache_dir / h5_file_path.name
-            logger.info(f"Caching H5 file {h5_file_path} to {cache_file_path}")
+            logger.debug(f"Caching H5 file {h5_file_path} to {cache_file_path}")
             if not cache_file_path.exists():
                 self._apply_throttling()
                 # Copy to a temp file first and then atomically rename it to avoid
@@ -932,10 +932,6 @@ def read_h5_file(
             or k in ["timestamps"]
         }

-        # Log the dtype for each modality
-        for k, v in sample_dict.items():
-            logger.debug(f"Modality {k} has dtype {v.dtype}")
-
         if (
             missing_mask_group_name
             := ConvertToH5py.missing_timesteps_mask_group_name
```
