11#!/usr/bin/env python3
22
33import sys
4+
45sys .path .append ("/home/ubuntu/worldcover/model" )
56
67import os
78import tempfile
89from math import floor
910from pathlib import Path
10- import requests
1111
1212import boto3
1313import einops
1414import geopandas as gpd
15- import pandas as pd
1615import numpy
17- import pyarrow as pa
16+ import pandas as pd
1817import rasterio
18+ import requests
1919import shapely
2020import torch
2121import xarray as xr
2222from rasterio .windows import Window
23- from shapely import box
2423from torchvision .transforms import v2
2524
2625from src .datamodule import ClayDataset
@@ -152,53 +151,54 @@ def download_image(url):
152151 else :
153152 raise Exception ("Failed to download the image" )
154153
154+
155155def patches_and_windows_from_url (url , chunk_size = (PATCH_SIZE , PATCH_SIZE )):
156156 # Download the image from the URL
157157 image_data = download_image (url )
158-
158+
159159 # Open the image using rasterio from memory
160160 with rasterio .io .MemoryFile (image_data ) as memfile :
161161 with memfile .open () as src :
162162 # Read the image data and metadata
163163 img_data = src .read ()
164164 img_meta = src .profile
165165 img_crs = src .crs
166-
166+
167167 # Convert raster data and metadata into an xarray DataArray
168168 img_da = xr .DataArray (img_data , dims = ("band" , "y" , "x" ), attrs = img_meta )
169-
169+
170170 # Tile the data
171171 ds_chunked = img_da .chunk ({"y" : chunk_size [0 ], "x" : chunk_size [1 ]})
172-
172+
173173 # Get the geospatial information from the original dataset
174174 transform = img_meta ["transform" ]
175-
175+
176176 # Iterate over the chunks and compute the geospatial bounds for each chunk
177177 chunk_bounds = {}
178-
178+
179179 for x in range (ds_chunked .sizes ["x" ] // chunk_size [1 ]):
180180 for y in range (ds_chunked .sizes ["y" ] // chunk_size [0 ]):
181181 # Compute chunk coordinates
182182 x_start = x * chunk_size [1 ]
183183 y_start = y * chunk_size [0 ]
184184 x_end = min (x_start + chunk_size [1 ], ds_chunked .sizes ["x" ])
185185 y_end = min (y_start + chunk_size [0 ], ds_chunked .sizes ["y" ])
186-
186+
187187 # Compute chunk geospatial bounds
188188 lon_start , lat_start = transform * (x_start , y_start )
189189 lon_end , lat_end = transform * (x_end , y_end )
190-
190+
191191 # Store chunk bounds
192192 chunk_bounds [(x , y )] = {
193193 "lon_start" : lon_start ,
194194 "lat_start" : lat_start ,
195195 "lon_end" : lon_end ,
196196 "lat_end" : lat_end ,
197197 }
198-
198+
199199 return chunk_bounds , img_crs
200200
201-
201+
202202def make_batch (result ):
203203 pixels = []
204204 for url , win in result :
@@ -233,10 +233,10 @@ def make_batch(result):
233233 "timestep" : torch .as_tensor (data = [ds .normalize_timestamp (f"{ YEAR } -06-01" )]).to (
234234 rgb_model .device
235235 ),
236- "date" : f"{ YEAR } -06-01"
237- ,
236+ "date" : f"{ YEAR } -06-01" ,
238237 }
239238
239+
240240def get_pixels (result ):
241241 pixels = []
242242 for url , win in result :
@@ -328,89 +328,89 @@ def get_pixels(result):
328328 )
329329
330330 yoff += CHIP_SIZE
331-
332-
333331
334332 print (len (embeddings ), len (results ))
335- #embeddings = numpy.vstack(embeddings)
333+ # embeddings = numpy.vstack(embeddings)
336334 embeddings_ = embeddings [0 ]
337335 print ("Embeddings shape: " , embeddings_ .shape )
338-
336+
339337 embeddings_ = embeddings_ [:, :- 2 , :]
340-
341- print (f"Embeddings have shape { embeddings_ .shape } " ) # .mean(axis=1)
342-
338+
339+ print (f"Embeddings have shape { embeddings_ .shape } " ) # .mean(axis=1)
340+
343341 # remove date and lat/lon and reshape to disaggregated patches
344342 embeddings_patch = embeddings_ .reshape ([2 , 16 , 16 , 768 ])
345-
343+
346344 # average over the band groups
347345 embeddings_mean = embeddings_patch .mean (axis = 0 )
348-
349- print (f"Average patch embeddings have shape { embeddings_mean .shape } " )
350346
347+ print (f"Average patch embeddings have shape { embeddings_mean .shape } " )
351348
352349 if result is not None :
353350 print ("result: " , result [0 ][0 ])
354351 pix = get_pixels (result )
355352 chunk_bounds , epsg = patches_and_windows_from_url (result [0 ][0 ])
356- #print("chunk_bounds: ", chunk_bounds)
353+ # print("chunk_bounds: ", chunk_bounds)
357354 print ("chunk bounds length:" , len (chunk_bounds ))
358-
355+
359356 # Iterate through each patch
360357 for i in range (embeddings_mean .shape [0 ]):
361358 for j in range (embeddings_mean .shape [1 ]):
362359 embeddings_output_patch = embeddings_mean [i , j ]
363-
360+
364361 item_ = [
365- element for element in list (chunk_bounds .items ()) if element [0 ] == (i , j )
362+ element
363+ for element in list (chunk_bounds .items ())
364+ if element [0 ] == (i , j )
366365 ]
367366 box_ = [
368367 item_ [0 ][1 ]["lon_start" ],
369368 item_ [0 ][1 ]["lat_start" ],
370369 item_ [0 ][1 ]["lon_end" ],
371370 item_ [0 ][1 ]["lat_end" ],
372371 ]
373- #source_url = batch["source_url"]
372+ # source_url = batch["source_url"]
374373 date = batch ["date" ]
375374 date_as_timestamp = pd .to_datetime (date , format = "%Y-%m-%d" )
376375
377376 # Convert the Pandas Timestamp to the desired data type
378- #date_as_date32 = date_as_timestamp.astype('datetime64[D]')
377+ # date_as_date32 = date_as_timestamp.astype('datetime64[D]')
379378
380- #print(batch["date"])
379+ # print(batch["date"])
381380 data = {
382381 "date" : date_as_timestamp ,
383382 "embeddings" : [numpy .ascontiguousarray (embeddings_output_patch )],
384383 }
385-
384+
386385 # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
387386 # The box_ list is encoded as
388387 # [bottom left x, bottom left y, top right x, top right y]
389388 box_emb = shapely .geometry .box (box_ [0 ], box_ [1 ], box_ [2 ], box_ [3 ])
390389
391390 print (str (epsg )[- 4 :])
392-
391+
393392 # Create the GeoDataFrame
394- gdf = gpd .GeoDataFrame (data , geometry = [box_emb ], crs = f"EPSG:{ str (epsg )[- 4 :]} " )
395-
393+ gdf = gpd .GeoDataFrame (
394+ data , geometry = [box_emb ], crs = f"EPSG:{ str (epsg )[- 4 :]} "
395+ )
396+
396397 # Reproject to WGS84 (lon/lat coordinates)
397398 gdf = gdf .to_crs (epsg = 4326 )
398-
399-
399+
400400 with tempfile .TemporaryDirectory () as tmp :
401401 # tmp = "/home/tam/Desktop/wcctmp"
402-
402+
403403 outpath = f"{ tmp } /worldcover_patch_embeddings_{ YEAR } _{ index } _{ i } _{ j } _v{ VERSION } .gpq"
404404 print (f"Uploading embeddings to { outpath } " )
405- #print(gdf)
406-
407- gdf .to_parquet (path = outpath , compression = "ZSTD" , schema_version = "1.0.0" )
408-
405+ # print(gdf)
406+
407+ gdf .to_parquet (
408+ path = outpath , compression = "ZSTD" , schema_version = "1.0.0"
409+ )
410+
409411 s3_client = boto3 .client ("s3" )
410412 s3_client .upload_file (
411413 outpath ,
412414 BUCKET ,
413415 f"v{ VERSION } /{ YEAR } /{ os .path .basename (outpath )} " ,
414416 )
415-
416-
0 commit comments