|
17 | 17 | "outputs": [],
|
18 | 18 | "source": [
|
19 | 19 | "import sys\n",
|
| 20 | + "\n", |
20 | 21 | "sys.path.append(\"..\")"
|
21 | 22 | ]
|
22 | 23 | },
|
|
27 | 28 | "metadata": {},
|
28 | 29 | "outputs": [],
|
29 | 30 | "source": [
|
30 |
| - "import os\n", |
31 | 31 | "import glob\n",
|
32 | 32 | "import math\n",
|
33 |
| - "import boto3\n", |
34 |
| - "import yaml\n", |
| 33 | + "import os\n", |
35 | 34 | "import random\n",
|
36 |
| - "import numpy as np\n", |
37 |
| - "import pandas as pd\n", |
| 35 | + "\n", |
38 | 36 | "import geopandas as gpd\n",
|
| 37 | + "import lancedb\n", |
39 | 38 | "import matplotlib.pyplot as plt\n",
|
40 |
| - "import xarray as xr\n", |
| 39 | + "import numpy as np\n", |
| 40 | + "import pandas as pd\n", |
41 | 41 | "import rioxarray # noqa: F401\n",
|
| 42 | + "import shapely\n", |
42 | 43 | "import torch\n",
|
43 |
| - "import stackstac\n", |
44 |
| - "from pystac_client import Client\n", |
45 |
| - "import pystac_client\n", |
| 44 | + "import xarray as xr\n", |
| 45 | + "import yaml\n", |
46 | 46 | "from box import Box\n",
|
47 |
| - "import lancedb\n", |
48 |
| - "from pathlib import Path\n", |
49 |
| - "import shapely\n", |
50 |
| - "from einops import rearrange\n", |
51 |
| - "from torchvision.transforms import v2\n", |
| 47 | + "from pystac_client import Client\n", |
52 | 48 | "from stacchip.processors.prechip import normalize_timestamp\n",
|
| 49 | + "from torchvision.transforms import v2\n", |
53 | 50 | "\n",
|
54 |
| - "from src.datamodule import ClayDataModule\n", |
55 |
| - "from src.model_clay_v1 import ClayMAEModule\n" |
| 51 | + "from src.model_clay_v1 import ClayMAEModule" |
56 | 52 | ]
|
57 | 53 | },
|
58 | 54 | {
|
|
69 | 65 | " Parameters:\n",
|
70 | 66 | " stack (xarray.DataArray): The input data array containing band information.\n",
|
71 | 67 | " \"\"\"\n",
|
72 |
| - " stack.sel(band=[1, 2, 3]).plot.imshow(\n", |
73 |
| - " rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", |
74 |
| - " )\n", |
| 68 | + " stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n", |
75 | 69 | " plt.show()\n",
|
76 |
| - " \n", |
| 70 | + "\n", |
| 71 | + "\n", |
77 | 72 | "def normalize_latlon(lat, lon):\n",
|
78 | 73 | " \"\"\"\n",
|
79 | 74 | " Normalize latitude and longitude to a range between -1 and 1.\n",
|
|
90 | 85 | "\n",
|
91 | 86 | " return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n",
|
92 | 87 | "\n",
|
| 88 | + "\n", |
93 | 89 | "def load_model(ckpt, device=\"cuda\"):\n",
|
94 | 90 | " \"\"\"\n",
|
95 | 91 | " Load a pretrained Clay model from a checkpoint.\n",
|
|
108 | 104 | " model.eval()\n",
|
109 | 105 | " return model.to(device)\n",
|
110 | 106 | "\n",
|
| 107 | + "\n", |
111 | 108 | "def prep_datacube(stack, lat, lon, device):\n",
|
112 | 109 | " \"\"\"\n",
|
113 | 110 | " Prepare a data cube for model input.\n",
|
|
169 | 166 | " \"waves\": torch.tensor(waves, device=device),\n",
|
170 | 167 | " }\n",
|
171 | 168 | "\n",
|
| 169 | + "\n", |
172 | 170 | "def generate_embeddings(model, datacube):\n",
|
173 | 171 | " \"\"\"\n",
|
174 | 172 | " Generate embeddings from the model using the data cube.\n",
|
|
184 | 182 | " unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
|
185 | 183 | "\n",
|
186 | 184 | " # The first embedding is the class token, which is the\n",
|
187 |
| - " # overall single embedding. \n", |
188 |
| - " return unmsk_patch[:, 0, :].cpu().numpy()\n" |
| 185 | + " # overall single embedding.\n", |
| 186 | + " return unmsk_patch[:, 0, :].cpu().numpy()" |
189 | 187 | ]
|
190 | 188 | },
|
191 | 189 | {
|
|
227 | 225 | " y_end = y_start + 256\n",
|
228 | 226 | "\n",
|
229 | 227 | " # Extract the tile from the cropped dataset\n",
|
230 |
| - " tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n", |
231 |
| - " \n", |
| 228 | + " tile = cropped_dataset.isel(\n", |
| 229 | + " x=slice(x_start, x_end), y=slice(y_start, y_end)\n", |
| 230 | + " )\n", |
| 231 | + "\n", |
232 | 232 | " # Calculate the centroid\n",
|
233 | 233 | " centroid_x = (tile.x * tile).sum() / tile.sum()\n",
|
234 | 234 | " centroid_y = (tile.y * tile).sum() / tile.sum()\n",
|
235 |
| - " \n", |
| 235 | + "\n", |
236 | 236 | " lon = centroid_x.item()\n",
|
237 | 237 | " lat = centroid_y.item()\n",
|
238 | 238 | "\n",
|
239 |
| - " tile = tile.assign_coords(band=['red','green','blue','nir'])\n", |
| 239 | + " tile = tile.assign_coords(band=[\"red\", \"green\", \"blue\", \"nir\"])\n", |
240 | 240 | " tile_save = tile\n",
|
241 | 241 | "\n",
|
242 |
| - " time_coord = xr.DataArray(['2020-01-01'], dims='time', name='time')\n", |
| 242 | + " time_coord = xr.DataArray([\"2020-01-01\"], dims=\"time\", name=\"time\")\n", |
243 | 243 | " tile = tile.expand_dims(time=[0])\n",
|
244 | 244 | " tile = tile.assign_coords(time=time_coord)\n",
|
245 | 245 | "\n",
|
246 |
| - " gsd_coord = xr.DataArray([0.6], dims='gsd', name='gsd')\n", |
| 246 | + " gsd_coord = xr.DataArray([0.6], dims=\"gsd\", name=\"gsd\")\n", |
247 | 247 | " tile = tile.expand_dims(gsd=[0])\n",
|
248 | 248 | " tile = tile.assign_coords(gsd=gsd_coord)\n",
|
249 | 249 | "\n",
|
|
253 | 253 | " tile_save.rio.to_raster(tile_path)\n",
|
254 | 254 | " tiles.append(tile)\n",
|
255 | 255 | " tile_names.append(tile_name)\n",
|
256 |
| - " \n", |
| 256 | + "\n", |
257 | 257 | " return tiles, tile_names"
|
258 | 258 | ]
|
259 | 259 | },
|
|
291 | 291 | "for item in items.get_all_items():\n",
|
292 | 292 | " assets = item.assets\n",
|
293 | 293 | " dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
|
294 |
| - " granule_name = item.assets[\"image\"].href.split('/')[-1]\n", |
| 294 | + " granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n", |
295 | 295 | " stackstac_datasets.append(dataset)\n",
|
296 | 296 | " granule_names.append(granule_name)"
|
297 | 297 | ]
|
|
321 | 321 | " tiles__ = []\n",
|
322 | 322 | " tile_names__ = []\n",
|
323 | 323 | " for filename in os.listdir(save_dir):\n",
|
324 |
| - " if filename.endswith(\".tif\"): \n", |
| 324 | + " if filename.endswith(\".tif\"):\n", |
325 | 325 | " tile_names__.append(filename)\n",
|
326 | 326 | " file_path = os.path.join(save_dir, filename)\n",
|
327 | 327 | " data_array = rioxarray.open_rasterio(file_path)\n",
|
|
391 | 391 | " # Calculate the centroid\n",
|
392 | 392 | " centroid_x = (tile.x * tile).sum() / tile.sum()\n",
|
393 | 393 | " centroid_y = (tile.y * tile).sum() / tile.sum()\n",
|
394 |
| - " \n", |
| 394 | + "\n", |
395 | 395 | " lon = centroid_x.item()\n",
|
396 | 396 | " lat = centroid_y.item()\n",
|
397 |
| - " \n", |
| 397 | + "\n", |
398 | 398 | " datacube = prep_datacube(tile, lat, lon, model.device)\n",
|
399 | 399 | " embeddings_ = generate_embeddings(model, datacube)\n",
|
400 | 400 | " embeddings.append(embeddings_)\n",
|
|
413 | 413 | " box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
|
414 | 414 | "\n",
|
415 | 415 | " # Create the GeoDataFrame\n",
|
416 |
| - " gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n", |
| 416 | + " gdf = gpd.GeoDataFrame(\n", |
| 417 | + " data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n", |
| 418 | + " )\n", |
417 | 419 | "\n",
|
418 | 420 | " # Reproject to WGS84 (lon/lat coordinates)\n",
|
419 | 421 | " gdf = gdf.to_crs(epsg=4326)\n",
|
420 | 422 | "\n",
|
421 |
| - " outpath = (\n", |
422 |
| - " f\"{outdir_embeddings}/\"\n", |
423 |
| - " f\"{fname[:-4]}.gpq\"\n", |
424 |
| - " )\n", |
| 423 | + " outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n", |
425 | 424 | " gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
|
426 | 425 | " print(\n",
|
427 | 426 | " f\"Saved {len(gdf)} rows of embeddings of \"\n",
|
|
494 | 493 | "data = []\n",
|
495 | 494 | "# Dataframe to find overlaps within\n",
|
496 | 495 | "gdfs = []\n",
|
497 |
| - "idx = 0\n", |
| 496 | + "idx = 0\n", |
498 | 497 | "for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
|
499 | 498 | " gdf = gpd.read_parquet(emb)\n",
|
500 | 499 | " gdf[\"year\"] = gdf.date.dt.year\n",
|
|
515 | 514 | " \"box\": row[\"box\"].bounds,\n",
|
516 | 515 | " }\n",
|
517 | 516 | " )\n",
|
518 |
| - " idx += 1\n" |
| 517 | + " idx += 1" |
519 | 518 | ]
|
520 | 519 | },
|
521 | 520 | {
|
|
637 | 636 | " for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
|
638 | 637 | " row = df.iloc[i]\n",
|
639 | 638 | " path = row[\"path\"]\n",
|
640 |
| - " chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue'])\n", |
641 |
| - " chip = chip.squeeze().transpose('x', 'y', 'band')\n", |
| 639 | + " chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n", |
| 640 | + " band=[\"red\", \"green\", \"blue\"]\n", |
| 641 | + " )\n", |
| 642 | + " chip = chip.squeeze().transpose(\"x\", \"y\", \"band\")\n", |
642 | 643 | " ax.imshow(chip)\n",
|
643 | 644 | " ax.set_title(f\"{row['idx']}\")\n",
|
644 | 645 | " ax.set_axis_off()\n",
|
|
0 commit comments