|
8 | 8 | "outputs": [],
|
9 | 9 | "source": [
|
10 | 10 | "import sys\n",
|
| 11 | + "\n", |
11 | 12 | "sys.path.append(\"..\")"
|
12 | 13 | ]
|
13 | 14 | },
|
|
18 | 19 | "metadata": {},
|
19 | 20 | "outputs": [],
|
20 | 21 | "source": [
|
21 |
| - "import os\n", |
22 | 22 | "import glob\n",
|
23 | 23 | "import math\n",
|
24 |
| - "import boto3\n", |
25 |
| - "import yaml\n", |
| 24 | + "import os\n", |
26 | 25 | "import random\n",
|
27 |
| - "import numpy as np\n", |
28 |
| - "import pandas as pd\n", |
| 26 | + "\n", |
29 | 27 | "import geopandas as gpd\n",
|
30 |
| - "from shapely import Point\n", |
31 |
| - "from sklearn import decomposition\n", |
| 28 | + "import lancedb\n", |
32 | 29 | "import matplotlib.pyplot as plt\n",
|
33 |
| - "import xarray as xr\n", |
| 30 | + "import numpy as np\n", |
| 31 | + "import pandas as pd\n", |
34 | 32 | "import rioxarray # noqa: F401\n",
|
| 33 | + "import shapely # .geometry import Point, Polygon, box\n", |
35 | 34 | "import torch\n",
|
36 |
| - "import stackstac\n", |
37 |
| - "from pystac_client import Client\n", |
38 |
| - "import pystac_client\n", |
| 35 | + "import xarray as xr\n", |
| 36 | + "import yaml\n", |
39 | 37 | "from box import Box\n",
|
40 |
| - "import lancedb\n", |
41 |
| - "from pathlib import Path\n", |
42 |
| - "import shapely #.geometry import Point, Polygon, box\n", |
43 |
| - "from einops import rearrange\n", |
44 |
| - "from torchvision.transforms import v2\n", |
| 38 | + "from pystac_client import Client\n", |
45 | 39 | "from stacchip.processors.prechip import normalize_timestamp\n",
|
| 40 | + "from torchvision.transforms import v2\n", |
46 | 41 | "\n",
|
47 |
| - "from src.datamodule import ClayDataModule\n", |
48 |
| - "from src.model_clay_v1 import ClayMAEModule\n" |
| 42 | + "from src.model_clay_v1 import ClayMAEModule" |
49 | 43 | ]
|
50 | 44 | },
|
51 | 45 | {
|
|
56 | 50 | "outputs": [],
|
57 | 51 | "source": [
|
58 | 52 | "def plot_rgb(stack):\n",
|
59 |
| - " stack.sel(band=[1, 2, 3]).plot.imshow(\n", |
60 |
| - " rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", |
61 |
| - " )\n", |
| 53 | + " stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n", |
62 | 54 | " plt.show()\n",
|
63 |
| - " \n", |
| 55 | + "\n", |
| 56 | + "\n", |
64 | 57 | "def normalize_latlon(lat, lon):\n",
|
65 | 58 | " lat = lat * np.pi / 180\n",
|
66 | 59 | " lon = lon * np.pi / 180\n",
|
|
130 | 123 | "\n",
|
131 | 124 | "\n",
|
132 | 125 | "def generate_embeddings(model, datacube):\n",
|
133 |
| - " #print(datacube)\n", |
| 126 | + " # print(datacube)\n", |
134 | 127 | " with torch.no_grad():\n",
|
135 | 128 | " unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
|
136 | 129 | "\n",
|
137 | 130 | " # The first embedding is the class token, which is the\n",
|
138 | 131 | " # overall single embedding. We extract that for PCA below.\n",
|
139 |
| - " return unmsk_patch[:, 0, :].cpu().numpy()\n" |
| 132 | + " return unmsk_patch[:, 0, :].cpu().numpy()" |
140 | 133 | ]
|
141 | 134 | },
|
142 | 135 | {
|
|
172 | 165 | "for item in items.get_all_items():\n",
|
173 | 166 | " assets = item.assets\n",
|
174 | 167 | " dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
|
175 |
| - " granule_name = item.assets[\"image\"].href.split('/')[-1]\n", |
| 168 | + " granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n", |
176 | 169 | " stackstac_datasets.append(dataset)\n",
|
177 | 170 | " granule_names.append(granule_name)\n",
|
178 |
| - " \n", |
179 |
| - " \n", |
| 171 | + "\n", |
180 | 172 | "\n",
|
181 | 173 | "# Function to tile dataset into 256x256 image chips and drop any excess border regions\n",
|
182 | 174 | "def tile_dataset(dataset, granule_name):\n",
|
|
202 | 194 | " y_end = y_start + 256\n",
|
203 | 195 | "\n",
|
204 | 196 | " # Extract the tile from the cropped dataset\n",
|
205 |
| - " tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n", |
206 |
| - " \n", |
| 197 | + " tile = cropped_dataset.isel(\n", |
| 198 | + " x=slice(x_start, x_end), y=slice(y_start, y_end)\n", |
| 199 | + " )\n", |
| 200 | + "\n", |
207 | 201 | " # Calculate the centroid\n",
|
208 | 202 | " centroid_x = (tile.x * tile).sum() / tile.sum()\n",
|
209 | 203 | " centroid_y = (tile.y * tile).sum() / tile.sum()\n",
|
210 |
| - " \n", |
| 204 | + "\n", |
211 | 205 | " # Print or use the centroid coordinates\n",
|
212 |
| - " #print(\"Centroid X:\", centroid_x.item())\n", |
213 |
| - " #print(\"Centroid Y:\", centroid_y.item())\n", |
214 |
| - " \n", |
| 206 | + " # print(\"Centroid X:\", centroid_x.item())\n", |
| 207 | + " # print(\"Centroid Y:\", centroid_y.item())\n", |
| 208 | + "\n", |
215 | 209 | " lon = centroid_x.item()\n",
|
216 | 210 | " lat = centroid_y.item()\n",
|
217 | 211 | "\n",
|
218 |
| - " tile = tile.assign_coords(band=['red','green','blue','nir'])\n", |
| 212 | + " tile = tile.assign_coords(band=[\"red\", \"green\", \"blue\", \"nir\"])\n", |
219 | 213 | " tile_save = tile\n",
|
220 | 214 | "\n",
|
221 |
| - " time_coord = xr.DataArray(['2020-01-01'], dims='time', name='time')\n", |
| 215 | + " time_coord = xr.DataArray([\"2020-01-01\"], dims=\"time\", name=\"time\")\n", |
222 | 216 | "\n",
|
223 | 217 | " # Assign the time coordinate to the DataArray\n",
|
224 | 218 | " tile = tile.expand_dims(time=[0])\n",
|
225 | 219 | " tile = tile.assign_coords(time=time_coord)\n",
|
226 | 220 | "\n",
|
227 |
| - " gsd_coord = xr.DataArray([0.6], dims='gsd', name='gsd')\n", |
| 221 | + " gsd_coord = xr.DataArray([0.6], dims=\"gsd\", name=\"gsd\")\n", |
228 | 222 | "\n",
|
229 | 223 | " # Assign the time coordinate to the DataArray\n",
|
230 | 224 | " tile = tile.expand_dims(gsd=[0])\n",
|
231 | 225 | " tile = tile.assign_coords(gsd=gsd_coord)\n",
|
232 | 226 | "\n",
|
233 | 227 | " tile_name = f\"{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
|
234 |
| - " #name_coord = xr.DataArray(tile_name, dims='filename', name='filename')\n", |
| 228 | + " # name_coord = xr.DataArray(tile_name, dims='filename', name='filename')\n", |
235 | 229 | "\n",
|
236 | 230 | " # Assign the time coordinate to the DataArray\n",
|
237 |
| - " #tile = tile.expand_dims(filename=[0])\n", |
238 |
| - " #tile = tile.assign_coords(filename=name_coord)\n", |
| 231 | + " # tile = tile.expand_dims(filename=[0])\n", |
| 232 | + " # tile = tile.assign_coords(filename=name_coord)\n", |
239 | 233 | "\n",
|
240 |
| - " #print(tile)\n", |
| 234 | + " # print(tile)\n", |
241 | 235 | "\n",
|
242 | 236 | " # Save the tile as a GeoTIFF\n",
|
243 | 237 | " tile_path = f\"{save_dir}/{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
|
244 | 238 | " tile_save.rio.to_raster(tile_path)\n",
|
245 | 239 | " tiles.append(tile)\n",
|
246 | 240 | " tile_names.append(tile_name)\n",
|
247 |
| - " \n", |
| 241 | + "\n", |
248 | 242 | " return tiles, tile_names\n",
|
249 |
| - " \n", |
| 243 | + "\n", |
250 | 244 | "\n",
|
251 | 245 | "make_tiles = False\n",
|
252 | 246 | "\n",
|
253 | 247 | "if make_tiles:\n",
|
254 | 248 | " tiles_ = []\n",
|
255 | 249 | " tile_names_ = []\n",
|
256 |
| - " \n", |
257 |
| - " \n", |
| 250 | + "\n", |
258 | 251 | " # Tile each dataset\n",
|
259 | 252 | " for dataset, granule_name in zip(stackstac_datasets, granule_names):\n",
|
260 | 253 | " tiles, tile_names = tile_dataset(dataset, granule_name)\n",
|
261 | 254 | " tiles_.append(tiles)\n",
|
262 | 255 | " tile_names_.append(tile_names)\n",
|
263 |
| - " #tiles, tile_names = tile_dataset(stackstac_datasets[0], granule_names[0])\n", |
| 256 | + " # tiles, tile_names = tile_dataset(stackstac_datasets[0], granule_names[0])\n", |
264 | 257 | " tiles__ = [tile for tile in tiles for tile_ in tiles_]\n",
|
265 | 258 | " tile_names__ = [tile for tile in tile_names for tile_ in tile_names_]\n",
|
266 | 259 | "else:\n",
|
267 | 260 | " tiles__ = []\n",
|
268 | 261 | " tile_names__ = []\n",
|
269 | 262 | " for filename in os.listdir(save_dir):\n",
|
270 |
| - " if filename.endswith(\".tif\"): \n", |
| 263 | + " if filename.endswith(\".tif\"):\n", |
271 | 264 | " tile_names__.append(filename)\n",
|
272 | 265 | " file_path = os.path.join(save_dir, filename)\n",
|
273 | 266 | " data_array = rioxarray.open_rasterio(file_path)\n",
|
|
317 | 310 | "source": [
|
318 | 311 | "model = load_model(\n",
|
319 | 312 | " # ckpt=\"s3://clay-model-ckpt/v0.5.3/mae_v0.5.3_epoch-29_val-loss-0.3073.ckpt\",\n",
|
320 |
| - " #ckpt=\"../checkpoints/v0.5.3/mae_v0.5.3_epoch-08_val-loss-0.3150.ckpt\",\n", |
| 313 | + " # ckpt=\"../checkpoints/v0.5.3/mae_v0.5.3_epoch-08_val-loss-0.3150.ckpt\",\n", |
321 | 314 | " ckpt=\"s3://clay-model-ckpt/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\",\n",
|
322 | 315 | " device=\"cuda\",\n",
|
323 | 316 | ")\n",
|
|
327 | 320 | " # Calculate the centroid\n",
|
328 | 321 | " centroid_x = (tile.x * tile).sum() / tile.sum()\n",
|
329 | 322 | " centroid_y = (tile.y * tile).sum() / tile.sum()\n",
|
330 |
| - " \n", |
| 323 | + "\n", |
331 | 324 | " # Print or use the centroid coordinates\n",
|
332 |
| - " #print(\"Centroid X:\", centroid_x.item())\n", |
333 |
| - " #print(\"Centroid Y:\", centroid_y.item())\n", |
334 |
| - " \n", |
| 325 | + " # print(\"Centroid X:\", centroid_x.item())\n", |
| 326 | + " # print(\"Centroid Y:\", centroid_y.item())\n", |
| 327 | + "\n", |
335 | 328 | " lon = centroid_x.item()\n",
|
336 | 329 | " lat = centroid_y.item()\n",
|
337 |
| - " \n", |
| 330 | + "\n", |
338 | 331 | " datacube = prep_datacube(tile, lat, lon, model.device)\n",
|
339 | 332 | " embeddings_ = generate_embeddings(model, datacube)\n",
|
340 | 333 | " embeddings.append(embeddings_)\n",
|
|
355 | 348 | " box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
|
356 | 349 | "\n",
|
357 | 350 | " # Create the GeoDataFrame\n",
|
358 |
| - " gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n", |
| 351 | + " gdf = gpd.GeoDataFrame(\n", |
| 352 | + " data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n", |
| 353 | + " )\n", |
359 | 354 | "\n",
|
360 | 355 | " # Reproject to WGS84 (lon/lat coordinates)\n",
|
361 | 356 | " gdf = gdf.to_crs(epsg=4326)\n",
|
362 | 357 | "\n",
|
363 |
| - " outpath = (\n", |
364 |
| - " f\"{outdir_embeddings}/\"\n", |
365 |
| - " f\"{fname[:-4]}.gpq\"\n", |
366 |
| - " )\n", |
| 358 | + " outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n", |
367 | 359 | " gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
|
368 | 360 | " print(\n",
|
369 | 361 | " f\"Saved {len(gdf)} rows of embeddings of \"\n",
|
370 | 362 | " f\"shape {gdf.embeddings.iloc[0].shape} to {outpath}\"\n",
|
371 | 363 | " )\n",
|
372 |
| - " i=i+1" |
| 364 | + " i = i + 1" |
373 | 365 | ]
|
374 | 366 | },
|
375 | 367 | {
|
|
435 | 427 | "data = []\n",
|
436 | 428 | "# Dataframe to find overlaps within\n",
|
437 | 429 | "gdfs = []\n",
|
438 |
| - "idx = 0\n", |
| 430 | + "idx = 0\n", |
439 | 431 | "for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
|
440 | 432 | " gdf = gpd.read_parquet(emb)\n",
|
441 | 433 | " gdf[\"year\"] = gdf.date.dt.year\n",
|
|
456 | 448 | " \"box\": row[\"box\"].bounds,\n",
|
457 | 449 | " }\n",
|
458 | 450 | " )\n",
|
459 |
| - " idx = idx+1" |
| 451 | + " idx = idx + 1" |
460 | 452 | ]
|
461 | 453 | },
|
462 | 454 | {
|
|
572 | 564 | "source": [
|
573 | 565 | "def plot(df, cols=10):\n",
|
574 | 566 | " fig, axs = plt.subplots(1, cols, figsize=(20, 10))\n",
|
575 |
| - " i=0\n", |
| 567 | + " i = 0\n", |
576 | 568 | " for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
|
577 | 569 | " row = df.iloc[i]\n",
|
578 | 570 | " path = row[\"path\"]\n",
|
579 |
| - " chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue']) #[1,2,3])\n", |
580 |
| - " #chip = tiles__[row[\"idx\"]].sel(band=['red', 'green', 'blue'])\n", |
| 571 | + " chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n", |
| 572 | + " band=[\"red\", \"green\", \"blue\"]\n", |
| 573 | + " ) # [1,2,3])\n", |
| 574 | + " # chip = tiles__[row[\"idx\"]].sel(band=['red', 'green', 'blue'])\n", |
581 | 575 | " width = chip.shape[-1]\n",
|
582 | 576 | " height = chip.shape[-1]\n",
|
583 | 577 | " chip = chip.squeeze()\n",
|
584 |
| - " chip = chip.transpose('x', 'y', 'band')\n", |
| 578 | + " chip = chip.transpose(\"x\", \"y\", \"band\")\n", |
585 | 579 | "\n",
|
586 | 580 | " ax.imshow(chip)\n",
|
587 | 581 | " ax.set_title(f\"{row['idx']}\")\n",
|
588 | 582 | " ax.set_axis_off()\n",
|
589 |
| - " i=i+1\n", |
| 583 | + " i = i + 1\n", |
590 | 584 | " plt.tight_layout()\n",
|
591 | 585 | " fig.savefig(\"similar.png\")"
|
592 | 586 | ]
|
|
0 commit comments