Skip to content

Commit f576734

Browse files
committed
Add inference quickstart guide
1 parent 29c305d commit f576734

1 file changed

Lines changed: 164 additions & 0 deletions

File tree

docs/inference_quickstart.md

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
## Inference Quickstart
2+
3+
This quickstart shows how to (1) initialize the OlmoEarth model, (2) obtain a satellite
4+
image suitable for input to the model, and (3) compute embeddings from that satellite
5+
image.
6+
7+
## Initializing the Model
8+
9+
First, setup a Python 3.12 environment. You can use your favorite Python package
10+
manager, but here is an example with uv:
11+
12+
```
13+
curl -LsSf https://astral.sh/uv/install.sh | sh
14+
cd /path/to/olmoearth_pretrain/
15+
uv sync
16+
```
17+
18+
Now we can use the `olmoearth_pretrain` library to initialize the model in pytorch.
19+
Below, we initialize the OlmoEarth-v1-Base model.
20+
21+
```python
22+
# TBD, I'm assuming we will have model IDs or something like that to easily initialize
23+
# the model.
24+
import json
25+
from pathlib import Path
26+
27+
import torch
28+
29+
from olmo_core.config import Config
30+
from olmo_core.distributed.checkpoint import load_model_and_optim_state
31+
32+
checkpoint_path = Path("/weka/dfive-default/helios/checkpoints/joer/phase2.0_base_lr0.0001_wd0.02/step667200")
33+
with (checkpoint_path / "config.json").open() as f:
34+
config_dict = json.load(f)
35+
model_config = Config.from_dict(config_dict["model"])
36+
37+
model = model_config.build()
38+
39+
train_module_dir = checkpoint_path / "model_and_optim"
40+
load_model_and_optim_state(str(train_module_dir), model)
41+
42+
device = torch.device("cuda")
43+
model.to(device)
44+
```
45+
46+
## Obtain Satellite Imagery
47+
48+
Here, we obtain one Sentinel-2 image from the ESA Copernicus Browser. If you want to
49+
apply the model on multiple images of a location, like a time series of Sentinel-1 and
50+
Sentinel-2 images, see the
51+
[OlmoEarth embedding](https://github.com/allenai/rslearn/blob/master/docs/examples/OlmoEarthEmbeddings.md).
52+
and [OlmoEarth fine-tuning](https://github.com/allenai/rslearn/blob/master/docs/examples/FinetuneOlmoEarth.md)
53+
guides in rslearn.
54+
55+
To download on image from the Copernicus Browser, follow these steps:
56+
57+
1. Navigate to https://browser.dataspace.copernicus.eu/. Press Login to sign up for an
58+
account and login.
59+
2. Go to the Search tab at the top-left. Check Sentinel-2, then check L2A. This selects
60+
Sentinel-2 L2A images, which are the type of Sentinel-2 images that OlmoEarth is
61+
pre-trained on.
62+
3. Modify the time range if desired. Also, use the area of interest tool at the top
63+
right to select a spatial area to search over.
64+
4. Then, press Search. We recommend looking through the results to find a less cloudy
65+
image. You can press Visualize to preview the satellite image in the Browser before
66+
downloading it. Once you are satisfied, press the download icon next to the image in
67+
the search results. Once the download is complete, unzip the file.
68+
69+
If you prefer to skip using the browser, you can download and unzip a Sentinel-2 image
70+
of Seattle:
71+
72+
```
73+
wget https://storage.googleapis.com/ai2-rslearn-projects-data/artifacts/example_sentinel2_l2a_scene_of_seattle.zip
74+
unzip example_sentinel2_l2a_scene_of_seattle.zip
75+
```
76+
77+
## Compute Embeddings
78+
79+
Finally, we load the image in Python, normalize it, and apply the model on it to
80+
compute embeddings.
81+
82+
First, we read all of the image bands at 10 m/pixel. We use the B02 band (one of the
83+
10 m/pixel bands) to determine the transform under which to read the remaining bands,
84+
since some are stored at lower resolutions. Note that here we assume that the .SAFE
85+
folder is in the working directory.
86+
87+
```python
88+
import glob
89+
import numpy as np
90+
import rasterio
91+
from olmoearth_pretrain.data.constants import Modality
92+
from rasterio.vrt import WarpedVRT
93+
from rasterio.enums import Resampling
94+
95+
# Get the JP2 filenames that we need to read, in the band order expected by OlmoEarth.
96+
fnames = []
97+
for band_name in Modality.SENTINEL2_L2A.band_order:
98+
fname = glob.glob(f"*.SAFE/GRANULE/*/IMG_DATA/*/*_{band_name}_*.jp2")[0]
99+
fnames.append(fname)
100+
101+
# Get the CRS and transform from the first band, which is B02.
102+
with rasterio.open(fnames[0]) as src:
103+
crs = src.crs
104+
transform = src.transform
105+
width = src.width
106+
height = src.height
107+
108+
# We limit the width/height to 512x512 in case there is limited memory.
109+
width = 512
110+
height = 512
111+
112+
# Now read all of the bands.
113+
image = np.zeros((len(fnames), height, width), dtype=np.int32)
114+
for band_idx, fname in enumerate(fnames):
115+
with rasterio.open(fname) as src:
116+
with rasterio.vrt.WarpedVRT(
117+
src,
118+
crs=crs,
119+
transform=transform,
120+
width=width,
121+
height=height,
122+
resampling=Resampling.bilinear,
123+
) as vrt:
124+
image[band_idx, :, :] = vrt.read(1)
125+
126+
# Rearrange to BHWTC.
127+
image = image.transpose(1, 2, 0)[None, :, :, None, :]
128+
```
129+
130+
Next, we normalize the image:
131+
132+
```python
133+
from olmoearth_pretrain.data.normalize import Normalizer, Strategy
134+
135+
normalizer = Normalizer(Strategy.COMPUTED)
136+
image = normalizer.normalize(Modality.SENTINEL2_L2A, image)
137+
```
138+
139+
Now we can apply the model on the image. We recommend applying it on inputs between
140+
1x1 and 128x128 in size. The patch size can be set between 1 and 8; smaller patch sizes
141+
generally perform better, but require more GPU time.
142+
143+
```python
144+
from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
145+
146+
# Run the model on the topleft 64x64 of the image.
147+
sample = MaskedOlmoEarthSample(
148+
sentinel2_l2a=torch.tensor(image[:, 0:64, 0:64, :, :], dtype=torch.float32, device=device),
149+
# The mask shape is BHWTS, where S is the number of band sets (3 for Sentinel-2).
150+
sentinel2_l2a_mask=torch.ones((1, 64, 64, 1, 3), dtype=torch.float32, device=device) * MaskValue.ONLINE_ENCODER.value,
151+
# The timestamps is (day of month 1-31, month 0-11, year).
152+
# The values here correspond to the date of our sample Sentinel-2 image of Seattle
153+
# (2025-08-22).
154+
timestamps=torch.tensor([22, 7, 2025], device=device)[None, None, :],
155+
)
156+
tokens_and_masks = model.encoder(
157+
sample, fast_pass=True, patch_size=4,
158+
)["tokens_and_masks"]
159+
# Get the Sentinel-2 features.
160+
modality_features = tokens_and_masks.sentinel2_l2a
161+
# Pool the features over the timestep and band set dimensions so we end up with a BHWC
162+
# feature map.
163+
pooled = modality_features.mean(dim=[3, 4])
164+
```

0 commit comments

Comments
 (0)