Skip to content

Commit c962128

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5f410fc commit c962128

File tree

1 file changed

+46
-46
lines changed

1 file changed

+46
-46
lines changed

scripts/worldcover/run.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
#!/usr/bin/env python3
22

33
import sys
4+
45
sys.path.append("/home/ubuntu/worldcover/model")
56

67
import os
78
import tempfile
89
from math import floor
910
from pathlib import Path
10-
import requests
1111

1212
import boto3
1313
import einops
1414
import geopandas as gpd
15-
import pandas as pd
1615
import numpy
17-
import pyarrow as pa
16+
import pandas as pd
1817
import rasterio
18+
import requests
1919
import shapely
2020
import torch
2121
import xarray as xr
2222
from rasterio.windows import Window
23-
from shapely import box
2423
from torchvision.transforms import v2
2524

2625
from src.datamodule import ClayDataset
@@ -152,53 +151,54 @@ def download_image(url):
152151
else:
153152
raise Exception("Failed to download the image")
154153

154+
155155
def patches_and_windows_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):
    """Compute the georeferenced bounds of every full chunk of a remote raster.

    The raster behind ``url`` is fetched with ``download_image``, opened from
    memory with rasterio, and divided into a grid of ``chunk_size`` cells.
    Only whole cells are considered: the grid dimensions are obtained with
    floor division, so a trailing partial row or column is left out.

    Args:
        url: Location of the raster, passed unchanged to ``download_image``.
        chunk_size: ``(rows, cols)`` size of one chunk in pixels; index 0 is
            applied to the ``y`` dimension and index 1 to the ``x`` dimension.

    Returns:
        A ``(chunk_bounds, crs)`` tuple. ``chunk_bounds`` maps
        ``(x_index, y_index)`` to a dict with the keys ``lon_start``,
        ``lat_start``, ``lon_end`` and ``lat_end``, each produced by applying
        the raster's transform to the chunk's pixel corners. ``crs`` is the
        ``crs`` attribute of the opened dataset.

    NOTE(review): despite the ``lon``/``lat`` key names, the values are
    whatever the raster transform yields, i.e. presumably coordinates in the
    raster's own CRS -- confirm against the caller before treating them as
    geographic degrees.
    """
    # Fetch the encoded raster bytes.
    raw_bytes = download_image(url)

    # Decode entirely in memory; pixel values and metadata are copied out
    # before the dataset and the memory file are closed.
    with rasterio.io.MemoryFile(raw_bytes) as memfile, memfile.open() as src:
        pixels = src.read()
        profile = src.profile
        crs = src.crs

    # Wrap the pixels in a labelled array and tile it along y and x.
    raster = xr.DataArray(pixels, dims=("band", "y", "x"), attrs=profile)
    chunk_height = chunk_size[0]
    chunk_width = chunk_size[1]
    tiled = raster.chunk({"y": chunk_height, "x": chunk_width})

    # Transform taken from the raster profile; maps pixel (col, row) pairs
    # to coordinates.
    affine = profile["transform"]

    bounds_by_index = {}
    for col in range(tiled.sizes["x"] // chunk_width):
        for row in range(tiled.sizes["y"] // chunk_height):
            # Pixel extent of this chunk. Because the grid uses floor
            # division the clamp never actually shortens a chunk; it is
            # retained so the computation matches the established behaviour.
            left = col * chunk_width
            top = row * chunk_height
            right = min(left + chunk_width, tiled.sizes["x"])
            bottom = min(top + chunk_height, tiled.sizes["y"])

            # Project the two opposite pixel corners through the transform.
            start_x, start_y = affine * (left, top)
            end_x, end_y = affine * (right, bottom)

            bounds_by_index[(col, row)] = {
                "lon_start": start_x,
                "lat_start": start_y,
                "lon_end": end_x,
                "lat_end": end_y,
            }

    return bounds_by_index, crs
200200

201-
201+
202202
def make_batch(result):
203203
pixels = []
204204
for url, win in result:
@@ -233,10 +233,10 @@ def make_batch(result):
233233
"timestep": torch.as_tensor(data=[ds.normalize_timestamp(f"{YEAR}-06-01")]).to(
234234
rgb_model.device
235235
),
236-
"date": f"{YEAR}-06-01"
237-
,
236+
"date": f"{YEAR}-06-01",
238237
}
239238

239+
240240
def get_pixels(result):
241241
pixels = []
242242
for url, win in result:
@@ -328,89 +328,89 @@ def get_pixels(result):
328328
)
329329

330330
yoff += CHIP_SIZE
331-
332-
333331

334332
print(len(embeddings), len(results))
335-
#embeddings = numpy.vstack(embeddings)
333+
# embeddings = numpy.vstack(embeddings)
336334
embeddings_ = embeddings[0]
337335
print("Embeddings shape: ", embeddings_.shape)
338-
336+
339337
embeddings_ = embeddings_[:, :-2, :]
340-
341-
print(f"Embeddings have shape {embeddings_.shape}") #.mean(axis=1)
342-
338+
339+
print(f"Embeddings have shape {embeddings_.shape}") # .mean(axis=1)
340+
343341
# remove date and lat/lon and reshape to disaggregated patches
344342
embeddings_patch = embeddings_.reshape([2, 16, 16, 768])
345-
343+
346344
# average over the band groups
347345
embeddings_mean = embeddings_patch.mean(axis=0)
348-
349-
print(f"Average patch embeddings have shape {embeddings_mean.shape}")
350346

347+
print(f"Average patch embeddings have shape {embeddings_mean.shape}")
351348

352349
if result is not None:
353350
print("result: ", result[0][0])
354351
pix = get_pixels(result)
355352
chunk_bounds, epsg = patches_and_windows_from_url(result[0][0])
356-
#print("chunk_bounds: ", chunk_bounds)
353+
# print("chunk_bounds: ", chunk_bounds)
357354
print("chunk bounds length:", len(chunk_bounds))
358-
355+
359356
# Iterate through each patch
360357
for i in range(embeddings_mean.shape[0]):
361358
for j in range(embeddings_mean.shape[1]):
362359
embeddings_output_patch = embeddings_mean[i, j]
363-
360+
364361
item_ = [
365-
element for element in list(chunk_bounds.items()) if element[0] == (i, j)
362+
element
363+
for element in list(chunk_bounds.items())
364+
if element[0] == (i, j)
366365
]
367366
box_ = [
368367
item_[0][1]["lon_start"],
369368
item_[0][1]["lat_start"],
370369
item_[0][1]["lon_end"],
371370
item_[0][1]["lat_end"],
372371
]
373-
#source_url = batch["source_url"]
372+
# source_url = batch["source_url"]
374373
date = batch["date"]
375374
date_as_timestamp = pd.to_datetime(date, format="%Y-%m-%d")
376375

377376
# Convert the Pandas Timestamp to the desired data type
378-
#date_as_date32 = date_as_timestamp.astype('datetime64[D]')
377+
# date_as_date32 = date_as_timestamp.astype('datetime64[D]')
379378

380-
#print(batch["date"])
379+
# print(batch["date"])
381380
data = {
382381
"date": date_as_timestamp,
383382
"embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
384383
}
385-
384+
386385
# Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
387386
# The box_ list is encoded as
388387
# [bottom left x, bottom left y, top right x, top right y]
389388
box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])
390389

391390
print(str(epsg)[-4:])
392-
391+
393392
# Create the GeoDataFrame
394-
gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}")
395-
393+
gdf = gpd.GeoDataFrame(
394+
data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}"
395+
)
396+
396397
# Reproject to WGS84 (lon/lat coordinates)
397398
gdf = gdf.to_crs(epsg=4326)
398-
399-
399+
400400
with tempfile.TemporaryDirectory() as tmp:
401401
# tmp = "/home/tam/Desktop/wcctmp"
402-
402+
403403
outpath = f"{tmp}/worldcover_patch_embeddings_{YEAR}_{index}_{i}_{j}_v{VERSION}.gpq"
404404
print(f"Uploading embeddings to {outpath}")
405-
#print(gdf)
406-
407-
gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
408-
405+
# print(gdf)
406+
407+
gdf.to_parquet(
408+
path=outpath, compression="ZSTD", schema_version="1.0.0"
409+
)
410+
409411
s3_client = boto3.client("s3")
410412
s3_client.upload_file(
411413
outpath,
412414
BUCKET,
413415
f"v{VERSION}/{YEAR}/{os.path.basename(outpath)}",
414416
)
415-
416-

0 commit comments

Comments
 (0)