Skip to content

Commit aebf485

Browse files
committed
add conversion script
1 parent 2e8d5ba commit aebf485

1 file changed

Lines changed: 162 additions & 0 deletions

File tree

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""Post-process ingested OlmoEarth-v1-Base embedding data into the OlmoEarth Pretrain dataset."""
2+
3+
import argparse
4+
import csv
5+
import multiprocessing
6+
from datetime import datetime
7+
8+
import numpy as np
9+
import tqdm
10+
from rslearn.data_sources import Item
11+
from rslearn.dataset import Dataset, Window
12+
from rslearn.utils.mp import star_imap_unordered
13+
from rslearn.utils.raster_array import RasterArray
14+
from upath import UPath
15+
16+
from olmoearth_pretrain.data.constants import Modality, TimeSpan
17+
from olmoearth_pretrain.dataset.utils import get_modality_fname
18+
19+
from ..constants import GEOTIFF_RASTER_FORMAT, METADATA_COLUMNS
20+
from ..util import get_modality_temp_meta_fname, get_window_metadata
21+
from .multitemporal_raster import get_adjusted_projection_and_bounds
22+
23+
LAYER_NAME = "olmoearth_v1_base_embedding"
24+
25+
26+
def convert_olmoearth_v1_base_embedding(window: Window, olmoearth_path: UPath) -> None:
27+
"""Add OlmoEarth-v1-Base embedding data for this window to the OlmoEarth Pretrain dataset.
28+
29+
Args:
30+
window: the rslearn window to read data from.
31+
olmoearth_path: OlmoEarth Pretrain dataset path to write to.
32+
"""
33+
window_metadata = get_window_metadata(window)
34+
35+
if not window.is_layer_completed(LAYER_NAME):
36+
return
37+
38+
# Get start time of the source Sentinel-2 data.
39+
# We look at sentinel2_l2a_mo01 through sentinel2_l2a_mo12 to determine time range.
40+
layer_datas = window.load_layer_datas()
41+
42+
# Find the time range from any available sentinel2 layer.
43+
start_time: datetime | None = None
44+
end_time: datetime | None = None
45+
for layer_name, layer_data in layer_datas.items():
46+
if not layer_name.startswith("sentinel2_l2a_mo"):
47+
continue
48+
for item_group in layer_data.serialized_item_groups:
49+
for item_data in item_group:
50+
item = Item.deserialize(item_data)
51+
t = item.geometry.time_range[0]
52+
if start_time is None or t < start_time:
53+
start_time = t
54+
t = item.geometry.time_range[1]
55+
if end_time is None or t > end_time:
56+
end_time = t
57+
58+
if start_time is None or end_time is None:
59+
raise ValueError(
60+
f"Window {window.name} has embeddings but no sentinel2_l2a layers to determine time range"
61+
)
62+
63+
assert len(Modality.OLMOEARTH_V1_BASE_EMBEDDING.band_sets) == 1
64+
band_set = Modality.OLMOEARTH_V1_BASE_EMBEDDING.band_sets[0]
65+
adjusted_projection, adjusted_bounds = get_adjusted_projection_and_bounds(
66+
Modality.OLMOEARTH_V1_BASE_EMBEDDING,
67+
band_set,
68+
window.projection,
69+
window.bounds,
70+
)
71+
raster_dir = window.get_raster_dir(LAYER_NAME, band_set.bands)
72+
raster = GEOTIFF_RASTER_FORMAT.decode_raster(
73+
raster_dir, adjusted_projection, adjusted_bounds
74+
)
75+
# Quantize float32 embeddings ([-1, 1]) to uint8 ([0, 255]).
76+
uint8_array = np.clip(raster.array * 128 + 128, 0, 255).astype(np.uint8)
77+
raster = RasterArray(
78+
array=uint8_array, timestamps=raster.timestamps, metadata=raster.metadata
79+
)
80+
dst_fname = get_modality_fname(
81+
olmoearth_path,
82+
Modality.OLMOEARTH_V1_BASE_EMBEDDING,
83+
TimeSpan.STATIC,
84+
window_metadata,
85+
band_set.get_resolution(),
86+
"tif",
87+
)
88+
GEOTIFF_RASTER_FORMAT.encode_raster(
89+
path=dst_fname.parent,
90+
projection=adjusted_projection,
91+
bounds=adjusted_bounds,
92+
raster=raster,
93+
fname=dst_fname.name,
94+
)
95+
metadata_fname = get_modality_temp_meta_fname(
96+
olmoearth_path,
97+
Modality.OLMOEARTH_V1_BASE_EMBEDDING,
98+
TimeSpan.STATIC,
99+
window.name,
100+
)
101+
metadata_fname.parent.mkdir(parents=True, exist_ok=True)
102+
with metadata_fname.open("w") as f:
103+
writer = csv.DictWriter(f, fieldnames=METADATA_COLUMNS)
104+
writer.writeheader()
105+
writer.writerow(
106+
dict(
107+
crs=window_metadata.crs,
108+
col=window_metadata.col,
109+
row=window_metadata.row,
110+
tile_time=window_metadata.time.isoformat(),
111+
image_idx="0",
112+
start_time=start_time.isoformat(),
113+
end_time=end_time.isoformat(),
114+
)
115+
)
116+
117+
118+
if __name__ == "__main__":
119+
multiprocessing.set_start_method("forkserver")
120+
121+
parser = argparse.ArgumentParser(
122+
description="Post-process OlmoEarth Pretrain data",
123+
)
124+
parser.add_argument(
125+
"--ds_path",
126+
type=str,
127+
help="Source rslearn dataset path",
128+
required=True,
129+
)
130+
parser.add_argument(
131+
"--olmoearth_path",
132+
type=str,
133+
help="Destination OlmoEarth Pretrain dataset path",
134+
required=True,
135+
)
136+
parser.add_argument(
137+
"--workers",
138+
type=int,
139+
help="Number of workers to use",
140+
default=32,
141+
)
142+
args = parser.parse_args()
143+
144+
dataset = Dataset(UPath(args.ds_path))
145+
olmoearth_path = UPath(args.olmoearth_path)
146+
147+
jobs = []
148+
for window in dataset.load_windows(
149+
workers=args.workers, show_progress=True, groups=["res_10"]
150+
):
151+
jobs.append(
152+
dict(
153+
window=window,
154+
olmoearth_path=olmoearth_path,
155+
)
156+
)
157+
158+
p = multiprocessing.Pool(args.workers)
159+
outputs = star_imap_unordered(p, convert_olmoearth_v1_base_embedding, jobs)
160+
for _ in tqdm.tqdm(outputs, total=len(jobs)):
161+
pass
162+
p.close()

0 commit comments

Comments
 (0)