@@ -1,26 +1,25 @@
 #!/usr/bin/env python3
 
 import sys
+
 sys.path.append("../../")
 
 import os
 import tempfile
 from math import floor
 from pathlib import Path
-import requests
 
 import boto3
 import einops
 import geopandas as gpd
-import pandas as pd
 import numpy
-import pyarrow as pa
+import pandas as pd
 import rasterio
+import requests
 import shapely
 import torch
 import xarray as xr
 from rasterio.windows import Window
-from shapely import box
 from torchvision.transforms import v2
 
 from src.datamodule import ClayDataset
@@ -141,6 +140,7 @@ def tiles_and_windows(input: Window):
 
     return result
 
+
 def download_image(url):
     # Download an image from a URL
     response = requests.get(url)
@@ -150,52 +150,54 @@ def download_image(url):
     else:
         raise Exception("Failed to download the image")
 
+
 def patch_bounds_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):
     # Download an image from a URL
     image_data = download_image(url)
-
+
     # Open the image using rasterio from memory
     with rasterio.io.MemoryFile(image_data) as memfile:
         with memfile.open() as src:
             # Read the image data and metadata
             img_data = src.read()
             img_meta = src.profile
             img_crs = src.crs
-
+
             # Convert raster data and metadata into an xarray DataArray
             img_da = xr.DataArray(img_data, dims=("band", "y", "x"), attrs=img_meta)
-
+
             # Tile the data
             ds_chunked = img_da.chunk({"y": chunk_size[0], "x": chunk_size[1]})
-
+
             # Get the geospatial information from the original dataset
             transform = img_meta["transform"]
-
+
             # Iterate over the chunks and compute the geospatial bounds for each chunk
             chunk_bounds = {}
-
+
             for x in range(ds_chunked.sizes["x"] // chunk_size[1]):
                 for y in range(ds_chunked.sizes["y"] // chunk_size[0]):
                     # Compute chunk coordinates
                     x_start = x * chunk_size[1]
                     y_start = y * chunk_size[0]
                     x_end = min(x_start + chunk_size[1], ds_chunked.sizes["x"])
                     y_end = min(y_start + chunk_size[0], ds_chunked.sizes["y"])
-
+
                     # Compute chunk geospatial bounds
                     lon_start, lat_start = transform * (x_start, y_start)
                     lon_end, lat_end = transform * (x_end, y_end)
-
+
                     # Store chunk bounds
                     chunk_bounds[(x, y)] = {
                         "lon_start": lon_start,
                         "lat_start": lat_start,
                         "lon_end": lon_end,
                         "lat_end": lat_end,
                     }
-
+
     return chunk_bounds, img_crs
 
+
 def make_batch(result):
     pixels = []
     for url, win in result:
@@ -230,10 +232,10 @@ def make_batch(result):
         "timestep": torch.as_tensor(data=[ds.normalize_timestamp(f"{YEAR}-06-01")]).to(
             rgb_model.device
         ),
-        "date": f"{YEAR}-06-01"
-        ,
+        "date": f"{YEAR}-06-01",
     }
 
+
 def get_pixels(result):
     pixels = []
     for url, win in result:
@@ -319,42 +321,41 @@ def get_pixels(result):
             )
 
         yoff += CHIP_SIZE
-
-
 
     print(len(embeddings), len(results))
     embeddings_ = numpy.vstack(embeddings)
-    #embeddings_ = embeddings[0]
+    # embeddings_ = embeddings[0]
     print("Embeddings shape: ", embeddings_.shape)
 
     # remove date and lat/lon
     embeddings_ = embeddings_[:, :-2, :].mean(axis=0)
-
+
     print(f"Embeddings have shape {embeddings_.shape}")
-
+
     # reshape to disaggregated patches
     embeddings_patch = embeddings_.reshape([2, 16, 16, 768])
-
+
     # average over the band groups
     embeddings_mean = embeddings_patch.mean(axis=0)
-
-    print(f"Average patch embeddings have shape {embeddings_mean.shape}")
 
+    print(f"Average patch embeddings have shape {embeddings_mean.shape}")
 
     if result is not None:
         print("result: ", result[0][0])
         pix = get_pixels(result)
         chunk_bounds, epsg = patch_bounds_from_url(result[0][0])
-        #print("chunk_bounds: ", chunk_bounds)
+        # print("chunk_bounds: ", chunk_bounds)
         print("chunk bounds length:", len(chunk_bounds))
-
+
         # Iterate through each patch
         for i in range(embeddings_mean.shape[0]):
             for j in range(embeddings_mean.shape[1]):
                 embeddings_output_patch = embeddings_mean[i, j]
-
+
                 item_ = [
-                    element for element in list(chunk_bounds.items()) if element[0] == (i, j)
+                    element
+                    for element in list(chunk_bounds.items())
+                    if element[0] == (i, j)
                 ]
                 box_ = [
                     item_[0][1]["lon_start"],
@@ -364,42 +365,44 @@ def get_pixels(result):
                 ]
 
                 data = {
-                    #"source_url": batch["source_url"][0],
-                    #"date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
+                    # "source_url": batch["source_url"][0],
+                    # "date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
                     #    dtype="date32[day][pyarrow]"
-                    #),
-                    #"date": pd.to_datetime(date, format="%Y-%m-%d", dtype="date32[day][pyarrow]"),
+                    # ),
+                    # "date": pd.to_datetime(date, format="%Y-%m-%d", dtype="date32[day][pyarrow]"),
                     "date": pd.to_datetime(batch["date"], format="%Y-%m-%d"),
                     "embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
                 }
-
+
                 # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
                 # The box_ list is encoded as
                 # [bottom left x, bottom left y, top right x, top right y]
                 box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])
 
                 print(str(epsg)[-4:])
-
+
                 # Create the GeoDataFrame
-                gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}")
-
+                gdf = gpd.GeoDataFrame(
+                    data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}"
+                )
+
                 # Reproject to WGS84 (lon/lat coordinates)
                 gdf = gdf.to_crs(epsg=4326)
-
+
                 with tempfile.TemporaryDirectory() as tmp:
                     # tmp = "/home/tam/Desktop/wcctmp"
-
+
                     outpath = f"{tmp}/worldcover_patch_embeddings_{YEAR}_{index}_{i}_{j}_v{VERSION}.gpq"
                     print(f"Uploading embeddings to {outpath}")
-                    #print(gdf)
-
-                    gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
-
+                    # print(gdf)
+
+                    gdf.to_parquet(
+                        path=outpath, compression="ZSTD", schema_version="1.0.0"
+                    )
+
                     s3_client = boto3.client("s3")
                     s3_client.upload_file(
                         outpath,
                         BUCKET,
                         f"v{VERSION}/{YEAR}/{os.path.basename(outpath)}",
                     )
-
-
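For context on the chunk-bounds logic this diff reformats: rasterio affine transforms map (column, row) pixel offsets to (x, y) coordinates when multiplied, which is what `transform * (x_start, y_start)` relies on inside patch_bounds_from_url. Below is a minimal sketch of that mapping with made-up transform values and a stand-in for PATCH_SIZE; it is not part of the commit.

from rasterio.transform import from_origin

# Hypothetical 10 m grid anchored at an arbitrary projected origin.
transform = from_origin(west=500000.0, north=4200000.0, xsize=10.0, ysize=10.0)

chunk = 256  # stand-in for the script's PATCH_SIZE constant
x_start, y_start = 0, 0
x_end, y_end = chunk, chunk

# Multiplying the affine by (col, row) yields that pixel corner's coordinates
# in the raster's CRS, so the two corners give the chunk's bounding box.
lon_start, lat_start = transform * (x_start, y_start)
lon_end, lat_end = transform * (x_end, y_end)
print(lon_start, lat_start, lon_end, lat_end)
# 500000.0 4200000.0 502560.0 4197440.0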