-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathgenerate_embeddings.py
More file actions
310 lines (254 loc) · 11.6 KB
/
generate_embeddings.py
File metadata and controls
310 lines (254 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
#!/usr/bin/env python3
"""
Generate embedding datasets from MajorTOM parquet files.
This script replicates the functionality of the
05-Generate-Major-TOM-Embeddings.ipynb notebook. It loads a chosen model,
wraps it with MajorTOM_Embedder, processes each row group in the input
parquet(s), and writes a GeoParquet file containing the embeddings and
spatial metadata.
Example:
python generate_embeddings.py \
--model_name dinov2 \
--meta_path /data384/datasets/Core-S2L2A/metadata.parquet \
--parquet_input /data384/datasets/Core-S2L2A/images/part_00001.parquet \
--output_path /data384/datasets/embeddings_test/dinov2_test.parquet \
--fragment_size 384
"""
import argparse
import hashlib
import os
import sys
import cv2
import geopandas as gpd
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
from fsspec.parquet import open_parquet_file
from pyproj import CRS, Transformer
from shapely.geometry import box
from shapely.ops import transform as shapely_transform
# Ensure project root is on path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from MajorTOM.embedder.MajorTOM_Embedder import MajorTOM_Embedder
from models.clay_model import ClayModel
from models.dinov2_model import DINOv2Model
from models.farslip_model import FarSLIPModel
from models.satclip_model import SatCLIPModel
from models.siglip_model import SigLIPModel
from models.olmoearth_model import OlmoEarthModel
from models.load_config import load_config
# Registry mapping the CLI --model_name choices to their wrapper classes.
# Keys must stay in sync with the argparse `choices` list in main().
MODEL_MAP = {
"dinov2": DINOv2Model,
"siglip": SigLIPModel,
"farslip": FarSLIPModel,
"satclip": SatCLIPModel,
"clay": ClayModel,
"olmoearth": OlmoEarthModel,
}
def get_model_kwargs(model_name, device):
    """Build model constructor kwargs from config.yaml or defaults.

    Args:
        model_name: Key of the model's section in the loaded config.
        device: Device string ("cuda"/"cpu") always forwarded to the model.

    Returns:
        dict: Always contains "device"; additionally any of the recognized
        optional keys ("ckpt_path", "model_name", "tokenizer_path",
        "model_size") that appear in the model's config section.
    """
    kwargs = {"device": device}
    config = load_config()
    if config and model_name in config:
        model_cfg = config[model_name]
        # Forward only the recognized keys so unrelated config entries
        # cannot leak into the model constructor.
        for key in ("ckpt_path", "model_name", "tokenizer_path", "model_size"):
            if key in model_cfg:
                kwargs[key] = model_cfg[key]
    return kwargs
def get_parquet_files(parquet_input):
    """Expand *parquet_input* into a list of parquet file paths.

    A single file path is returned as a one-element list; a directory is
    scanned (non-recursively) for ``*.parquet`` entries in sorted name order.

    Raises:
        ValueError: if the path is neither an existing file nor a directory.
    """
    if os.path.isfile(parquet_input):
        return [parquet_input]
    if os.path.isdir(parquet_input):
        return [
            os.path.join(parquet_input, entry)
            for entry in sorted(os.listdir(parquet_input))
            if entry.endswith(".parquet")
        ]
    raise ValueError(f"parquet_input must be a file or directory: {parquet_input}")
def resolve_meta_url(meta_path, parquet_file_path):
    """
    Resolve the metadata path.

    Absolute paths and paths that already exist are returned unchanged.
    Otherwise, when the parquet path is a local file, the metadata is looked
    up relative to the parquet file's grandparent directory; if that
    candidate exists it is returned, else the original path is passed back
    untouched.
    """
    # Nothing to resolve for absolute or already-existing paths.
    if os.path.isabs(meta_path) or os.path.exists(meta_path):
        return meta_path
    if not os.path.isfile(parquet_file_path):
        return meta_path
    # Try <parquet_dir>/../<meta_path> for local layouts.
    grandparent = os.path.dirname(os.path.dirname(parquet_file_path))
    candidate = os.path.join(grandparent, meta_path)
    return candidate if os.path.exists(candidate) else meta_path
def _embed_single_fragment(embedder, row, row_meta, device, fragment_size, img=None, footprint=None, crs=None):
    """
    Embed a pre-cropped image as one fragment, without tiling.

    When *img* is not supplied, the bands (plus footprint and CRS) are read
    from *row* via the embedder. The image is resized to ``fragment_size``
    if needed, encoded in a single forward pass, and returned together with
    the spatial metadata as a one-row GeoDataFrame.
    """
    if img is None:
        img, footprint, crs = embedder._read_image(row)
    height, width = img.shape[0], img.shape[1]
    # Normalize to a numpy array so both the resize and the no-resize path
    # end up as a torch tensor built the same way.
    pixels = img.numpy() if torch.is_tensor(img) else np.array(img)
    if (height, width) != (fragment_size, fragment_size):
        # cv2.resize takes (width, height); nearest-neighbour avoids mixing band values.
        pixels = cv2.resize(pixels, (fragment_size, fragment_size), interpolation=cv2.INTER_NEAREST)
    fragment = torch.from_numpy(pixels)
    # Whole-image encode: (H, W, C) -> (1, C, H, W)
    batch = fragment.permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = embedder.embedder(batch).cpu().numpy()[0]
    utm_footprint = footprint
    # Reproject the UTM footprint to WGS84 for the output geometry.
    to_wgs84 = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)
    geometry = shapely_transform(to_wgs84.transform, utm_footprint)
    centre_lon, centre_lat = geometry.centroid.coords[0]
    # Deterministic id derived from geometry, acquisition metadata and the embedding itself.
    combined = f"{geometry}_{row_meta.timestamp.item()}_{row_meta.product_id.item()}_{embedding}"
    unique_id = hashlib.sha256(combined.encode()).hexdigest()
    record = {
        'unique_id': unique_id,
        'embedding': embedding,
        'timestamp': row_meta.timestamp.item(),
        'product_id': row_meta.product_id.item(),
        'grid_cell': row_meta.grid_cell.item(),
        'grid_row_u': row_meta.grid_row_u.item(),
        'grid_col_r': row_meta.grid_col_r.item(),
        'geometry': geometry,
        'centre_lat': centre_lat,
        'centre_lon': centre_lon,
        'utm_footprint': utm_footprint.wkt,
        'utm_crs': crs.to_string(),
        'pixel_bbox': [0, 0, fragment_size, fragment_size],
        'parquet_row': row_meta.parquet_row.item() if 'parquet_row' in row_meta.columns else None,
        'parquet_url': row_meta.parquet_url.item() if 'parquet_url' in row_meta.columns else None,
    }
    narrow_types = {
        'grid_row_u': 'int16',
        'grid_col_r': 'int16',
        'centre_lat': 'float32',
        'centre_lon': 'float32',
    }
    return gpd.GeoDataFrame([record]).astype(narrow_types)
def generate_embeddings(model_name, meta_path, parquet_input, output_path, device=None, max_row_groups=None, fragment_size=None):
    """Main embedding generation logic.

    Loads the requested model, wraps it with MajorTOM_Embedder, embeds every
    row group of the input parquet file(s), and writes a single GeoParquet
    file with embeddings and spatial metadata.

    Args:
        model_name: Key into MODEL_MAP selecting the backbone.
        meta_path: Path to metadata.parquet (possibly relative, see resolve_meta_url).
        parquet_input: Parquet file or directory of parquet files.
        output_path: Destination GeoParquet path.
        device: "cuda"/"cpu"; auto-detected when None.
        max_row_groups: Optional cap on row groups processed per file.
        fragment_size: When set, pre-cropped images no larger than this are
            embedded as a single fragment instead of being tiled.

    Raises:
        ValueError: if model_name is not in MODEL_MAP.
    """
    if model_name not in MODEL_MAP:
        raise ValueError(f"Unknown model: {model_name}. Choose from {list(MODEL_MAP.keys())}")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    print(f"Loading {model_name} model...")
    # Load model (no embedding file needed)
    model_cls = MODEL_MAP[model_name]
    model_kwargs = get_model_kwargs(model_name, device)
    model = model_cls(**model_kwargs)
    print(f"Model bands: {model.bands}")
    print(f"Model input size: {model.size}")
    # Wrap with MajorTOM_Embedder
    embedder = MajorTOM_Embedder(model)
    embedder.to(device)
    # Override fragment_size if specified (e.g. for pre-cropped 384x384 imagery)
    if fragment_size is not None:
        embedder.frag_params['fragment_size'] = fragment_size
        print(f"Override fragment_size to {fragment_size}")
    use_single_fragment = fragment_size is not None
    parquet_files = get_parquet_files(parquet_input)
    print(f"Found {len(parquet_files)} parquet file(s) to process.")
    # Accumulate per-row-group frames and concatenate once at the end:
    # pd.concat inside the loop is quadratic in the number of row groups.
    frames = []
    total = 0
    for pf_path in parquet_files:
        print(f"\nProcessing {pf_path} ...")
        resolved_meta = resolve_meta_url(meta_path, pf_path)
        print(f"Loading metadata from {resolved_meta} ...")
        meta_df = pd.read_parquet(resolved_meta)
        bands = embedder.bands()
        columns = list(bands) + ["product_id", "grid_cell", "timestamp"]
        # Open parquet file; remote handles must be closed explicitly.
        remote_handle = None
        if os.path.isfile(pf_path):
            # Local file
            pf = pq.ParquetFile(pf_path)
        else:
            # Remote file via fsspec
            remote_handle = open_parquet_file(pf_path, columns=columns)
            pf = pq.ParquetFile(remote_handle)
        try:
            num_row_groups = pf.num_row_groups if max_row_groups is None else min(pf.num_row_groups, max_row_groups)
            for row_idx in range(num_row_groups):
                row = pf.read_row_group(row_idx, columns=columns)
                grid_cell = row["grid_cell"][0].as_py()
                product_id = row["product_id"][0].as_py()
                row_meta = meta_df[
                    (meta_df["grid_cell"] == grid_cell) & (meta_df["product_id"] == product_id)
                ].head(1)
                if row_meta.empty:
                    print(f" ⚠️ Metadata not found for {product_id} / {grid_cell}, skipping.")
                    continue
                if use_single_fragment:
                    # Peek at image size to decide whether to tile or treat as a single fragment
                    img, footprint, crs = embedder._read_image(row)
                    h, w = img.shape[:2]
                    if h <= fragment_size and w <= fragment_size:
                        embed_dict = _embed_single_fragment(embedder, row, row_meta, device, fragment_size, img=img, footprint=footprint, crs=crs)
                    else:
                        embed_dict = embedder(row, row_meta, device=device)
                else:
                    embed_dict = embedder(row, row_meta, device=device)
                frames.append(embed_dict)
                total += len(embed_dict)
                if (row_idx + 1) % 10 == 0 or row_idx == num_row_groups - 1:
                    print(f" Processed {row_idx + 1}/{num_row_groups} row groups, total embeddings: {total}")
        finally:
            # Previously leaked: the fsspec file object was never closed.
            if remote_handle is not None:
                remote_handle.close()
    if not frames:
        print("No embeddings were generated.")
        return
    embed_df = pd.concat(frames, ignore_index=True)
    if embed_df.empty:
        print("No embeddings were generated.")
        return
    embed_df = embed_df.reset_index(drop=True)
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    embed_df.to_parquet(output_path)
    print(f"\n✅ Saved {len(embed_df)} embeddings to {output_path}")
    # Sanity check
    sanity = pd.read_parquet(output_path)
    print("Sanity check columns:", sanity.columns.tolist())
    print(sanity.head())
def main():
    """CLI entry point: parse arguments and hand them to generate_embeddings.

    Argument dest names intentionally mirror the generate_embeddings
    signature, so the parsed namespace can be splatted straight through.
    """
    arg_parser = argparse.ArgumentParser(description="Generate MajorTOM embeddings")
    arg_parser.add_argument("--model_name", type=str, required=True,
                            choices=["dinov2", "siglip", "farslip", "satclip", "clay", "olmoearth"],
                            help="Model to use for embedding generation")
    arg_parser.add_argument("--meta_path", type=str, required=True,
                            help="Path to metadata.parquet")
    arg_parser.add_argument("--parquet_input", type=str, required=True,
                            help="Path to a parquet file or directory containing parquet files")
    arg_parser.add_argument("--output_path", type=str, required=True,
                            help="Output GeoParquet file path")
    arg_parser.add_argument("--device", type=str, default=None,
                            help="Device to run on (cuda/cpu). Auto-detected if omitted.")
    arg_parser.add_argument("--max_row_groups", type=int, default=None,
                            help="Maximum number of row groups to process per parquet file (default: all).")
    arg_parser.add_argument("--fragment_size", type=int, default=None,
                            help=(
                                "Override the default fragment size (model input size). "
                                "Useful for pre-cropped imagery (e.g. 384x384) where each image "
                                "should produce a single embedding instead of multiple fragments."
                            ))
    generate_embeddings(**vars(arg_parser.parse_args()))
if __name__ == "__main__":
    main()