Skip to content

Commit 49bf34a

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent be3cc33 commit 49bf34a

File tree

3 files changed

+107
-83
lines changed

3 files changed

+107
-83
lines changed

scripts/worldcover/embeddings_db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
db.table_names()
3232

3333
# Drop existing table if exists
34-
#db.drop_table("worldcover-2020-v001")
34+
# db.drop_table("worldcover-2020-v001")
3535

3636
# Create embeddings table and insert the vector data
3737
tbl = db.create_table("worldcover-2020-v001", data=data, mode="overwrite")
@@ -55,4 +55,4 @@ def plot(df, cols=10):
5555
# Select a vector by index, search for its 5 most similar entries, and plot
5656
v = tbl.to_pandas()["vector"].values[5]
5757
result = tbl.search(query=v).limit(5).to_pandas()
58-
plot(result, 5)
58+
plot(result, 5)

scripts/worldcover/run.py

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
#!/usr/bin/env python3
22

33
import sys
4+
45
sys.path.append("../../")
56

67
import os
78
import tempfile
89
from math import floor
910
from pathlib import Path
10-
import requests
1111

1212
import boto3
1313
import einops
1414
import geopandas as gpd
15-
import pandas as pd
1615
import numpy
17-
import pyarrow as pa
16+
import pandas as pd
1817
import rasterio
18+
import requests
1919
import shapely
2020
import torch
2121
import xarray as xr
2222
from rasterio.windows import Window
23-
from shapely import box
2423
from torchvision.transforms import v2
2524

2625
from src.datamodule import ClayDataset
@@ -141,6 +140,7 @@ def tiles_and_windows(input: Window):
141140

142141
return result
143142

143+
144144
def download_image(url):
145145
# Download an image from a URL
146146
response = requests.get(url)
@@ -150,52 +150,54 @@ def download_image(url):
150150
else:
151151
raise Exception("Failed to download the image")
152152

153+
153154
def patch_bounds_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):
    """
    Compute the georeferenced bounds of every full-size patch of a raster.

    The raster behind ``url`` is downloaded, opened from memory, and divided
    into a regular grid of ``chunk_size`` patches. Only patches that fit
    completely inside the raster are reported; a partial patch at the right
    or bottom edge is skipped (the grid counts use floor division).

    Parameters
    ----------
    url : str
        Location of the raster, passed unchanged to ``download_image``.
    chunk_size : tuple of int
        Patch size in pixels as ``(size along y, size along x)``.

    Returns
    -------
    chunk_bounds : dict
        Maps ``(x, y)`` patch grid indices to a dict with the keys
        ``lon_start``, ``lat_start``, ``lon_end`` and ``lat_end``. The values
        come from applying the raster's affine transform to the patch's pixel
        corners, so they are expressed in the raster's own CRS.
        NOTE(review): despite the lon/lat key names these are presumably not
        degrees unless the raster is already in a geographic CRS - confirm.
    img_crs
        The ``crs`` attribute of the opened raster dataset.

    Notes
    -----
    Any exception raised by ``download_image`` or by rasterio while opening
    the data propagates to the caller.
    """
    image_data = download_image(url)

    # Only the grid geometry is needed here, so take it from the dataset
    # handle instead of decoding every pixel into memory.
    with rasterio.io.MemoryFile(image_data) as memfile:
        with memfile.open() as src:
            width = src.width
            height = src.height
            transform = src.transform
            img_crs = src.crs

    chunk_height, chunk_width = chunk_size[0], chunk_size[1]

    chunk_bounds = {}
    for x in range(width // chunk_width):
        for y in range(height // chunk_height):
            # Pixel extent of this patch. Floor division above guarantees the
            # end offsets never exceed the raster size, so no clamping is
            # required.
            x_start = x * chunk_width
            y_start = y * chunk_height
            x_end = x_start + chunk_width
            y_end = y_start + chunk_height

            # The affine transform maps (column, row) pixel offsets to
            # coordinates in the raster's CRS.
            lon_start, lat_start = transform * (x_start, y_start)
            lon_end, lat_end = transform * (x_end, y_end)

            chunk_bounds[(x, y)] = {
                "lon_start": lon_start,
                "lat_start": lat_start,
                "lon_end": lon_end,
                "lat_end": lat_end,
            }

    return chunk_bounds, img_crs
198199

200+
199201
def make_batch(result):
200202
pixels = []
201203
for url, win in result:
@@ -230,10 +232,10 @@ def make_batch(result):
230232
"timestep": torch.as_tensor(data=[ds.normalize_timestamp(f"{YEAR}-06-01")]).to(
231233
rgb_model.device
232234
),
233-
"date": f"{YEAR}-06-01"
234-
,
235+
"date": f"{YEAR}-06-01",
235236
}
236237

238+
237239
def get_pixels(result):
238240
pixels = []
239241
for url, win in result:
@@ -319,42 +321,41 @@ def get_pixels(result):
319321
)
320322

321323
yoff += CHIP_SIZE
322-
323-
324324

325325
print(len(embeddings), len(results))
326326
embeddings_ = numpy.vstack(embeddings)
327-
#embeddings_ = embeddings[0]
327+
# embeddings_ = embeddings[0]
328328
print("Embeddings shape: ", embeddings_.shape)
329329

330330
# remove date and lat/lon
331331
embeddings_ = embeddings_[:, :-2, :].mean(axis=0)
332-
332+
333333
print(f"Embeddings have shape {embeddings_.shape}")
334-
334+
335335
# reshape to disaggregated patches
336336
embeddings_patch = embeddings_.reshape([2, 16, 16, 768])
337-
337+
338338
# average over the band groups
339339
embeddings_mean = embeddings_patch.mean(axis=0)
340-
341-
print(f"Average patch embeddings have shape {embeddings_mean.shape}")
342340

341+
print(f"Average patch embeddings have shape {embeddings_mean.shape}")
343342

344343
if result is not None:
345344
print("result: ", result[0][0])
346345
pix = get_pixels(result)
347346
chunk_bounds, epsg = patch_bounds_from_url(result[0][0])
348-
#print("chunk_bounds: ", chunk_bounds)
347+
# print("chunk_bounds: ", chunk_bounds)
349348
print("chunk bounds length:", len(chunk_bounds))
350-
349+
351350
# Iterate through each patch
352351
for i in range(embeddings_mean.shape[0]):
353352
for j in range(embeddings_mean.shape[1]):
354353
embeddings_output_patch = embeddings_mean[i, j]
355-
354+
356355
item_ = [
357-
element for element in list(chunk_bounds.items()) if element[0] == (i, j)
356+
element
357+
for element in list(chunk_bounds.items())
358+
if element[0] == (i, j)
358359
]
359360
box_ = [
360361
item_[0][1]["lon_start"],
@@ -364,42 +365,44 @@ def get_pixels(result):
364365
]
365366

366367
data = {
367-
#"source_url": batch["source_url"][0],
368-
#"date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
368+
# "source_url": batch["source_url"][0],
369+
# "date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
369370
# dtype="date32[day][pyarrow]"
370-
#),
371-
#"date": pd.to_datetime(date, format="%Y-%m-%d", dtype="date32[day][pyarrow]"),
371+
# ),
372+
# "date": pd.to_datetime(date, format="%Y-%m-%d", dtype="date32[day][pyarrow]"),
372373
"date": pd.to_datetime(batch["date"], format="%Y-%m-%d"),
373374
"embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
374375
}
375-
376+
376377
# Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
377378
# The box_ list is encoded as
378379
# [bottom left x, bottom left y, top right x, top right y]
379380
box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])
380381

381382
print(str(epsg)[-4:])
382-
383+
383384
# Create the GeoDataFrame
384-
gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}")
385-
385+
gdf = gpd.GeoDataFrame(
386+
data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}"
387+
)
388+
386389
# Reproject to WGS84 (lon/lat coordinates)
387390
gdf = gdf.to_crs(epsg=4326)
388-
391+
389392
with tempfile.TemporaryDirectory() as tmp:
390393
# tmp = "/home/tam/Desktop/wcctmp"
391-
394+
392395
outpath = f"{tmp}/worldcover_patch_embeddings_{YEAR}_{index}_{i}_{j}_v{VERSION}.gpq"
393396
print(f"Uploading embeddings to {outpath}")
394-
#print(gdf)
395-
396-
gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
397-
397+
# print(gdf)
398+
399+
gdf.to_parquet(
400+
path=outpath, compression="ZSTD", schema_version="1.0.0"
401+
)
402+
398403
s3_client = boto3.client("s3")
399404
s3_client.upload_file(
400405
outpath,
401406
BUCKET,
402407
f"v{VERSION}/{YEAR}/{os.path.basename(outpath)}",
403408
)
404-
405-

0 commit comments

Comments
 (0)