|
17 | 17 | "outputs": [],
|
18 | 18 | "source": [
|
19 | 19 | "import sys\n",
|
| 20 | + "\n", |
20 | 21 | "sys.path.append(\"..\")"
|
21 | 22 | ]
|
22 | 23 | },
|
|
27 | 28 | "metadata": {},
|
28 | 29 | "outputs": [],
|
29 | 30 | "source": [
|
30 |
| - "import os\n", |
31 | 31 | "import glob\n",
|
32 | 32 | "import math\n",
|
33 |
| - "import boto3\n", |
34 |
| - "import yaml\n", |
| 33 | + "import os\n", |
35 | 34 | "import random\n",
|
36 |
| - "import numpy as np\n", |
37 |
| - "import pandas as pd\n", |
| 35 | + "\n", |
38 | 36 | "import geopandas as gpd\n",
|
| 37 | + "import lancedb\n", |
39 | 38 | "import matplotlib.pyplot as plt\n",
|
40 |
| - "import xarray as xr\n", |
| 39 | + "import numpy as np\n", |
| 40 | + "import pandas as pd\n", |
| 41 | + "import pystac_client\n", |
41 | 42 | "import rioxarray # noqa: F401\n",
|
| 43 | + "import shapely\n", |
42 | 44 | "import torch\n",
|
43 |
| - "import stackstac\n", |
44 |
| - "from pystac_client import Client\n", |
45 |
| - "import pystac_client\n", |
| 45 | + "import yaml\n", |
46 | 46 | "from box import Box\n",
|
47 |
| - "import lancedb\n", |
48 |
| - "from pathlib import Path\n", |
49 |
| - "import shapely\n", |
50 |
| - "from einops import rearrange\n", |
51 |
| - "from torchvision.transforms import v2\n", |
52 |
| - "from stacchip.processors.prechip import normalize_timestamp\n", |
53 |
| - "from stacchip.indexer import NoStatsChipIndexer\n", |
54 | 47 | "from stacchip.chipper import Chipper\n",
|
| 48 | + "from stacchip.indexer import NoStatsChipIndexer\n", |
| 49 | + "from stacchip.processors.prechip import normalize_timestamp\n", |
| 50 | + "from torchvision.transforms import v2\n", |
55 | 51 | "\n",
|
56 |
| - "from src.datamodule import ClayDataModule\n", |
57 |
| - "from src.model_clay import CLAYModule\n" |
| 52 | + "from src.model_clay import CLAYModule" |
58 | 53 | ]
|
59 | 54 | },
|
60 | 55 | {
|
|
65 | 60 | "outputs": [],
|
66 | 61 | "source": [
|
67 | 62 | "# Query STAC catalog for NAIP data\n",
|
68 |
| - "catalog = pystac_client.Client.open(\"https://planetarycomputer.microsoft.com/api/stac/v1\") #\"https://earth-search.aws.element84.com/v1\")\n", |
| 63 | + "catalog = pystac_client.Client.open(\n", |
| 64 | + " \"https://planetarycomputer.microsoft.com/api/stac/v1\"\n", |
| 65 | + ") # \"https://earth-search.aws.element84.com/v1\")\n", |
69 | 66 | "\n",
|
70 | 67 | "\n",
|
71 | 68 | "items = catalog.search(\n",
|
|
99 | 96 | "\n",
|
100 | 97 | " # Get first chip for the \"image\" asset key\n",
|
101 | 98 | " for chip_id in random.sample(range(0, len(chipper)), 5):\n",
|
102 |
| - " chips.append(chipper[chip_id][\"image\"])\n" |
| 99 | + " chips.append(chipper[chip_id][\"image\"])" |
103 | 100 | ]
|
104 | 101 | },
|
105 | 102 | {
|
|
128 | 125 | }
|
129 | 126 | ],
|
130 | 127 | "source": [
|
131 |
| - "fig, ax = plt.subplots(1, 1, gridspec_kw={'wspace': 0.01, 'hspace': 0.01}, squeeze=True)\n", |
| 128 | + "fig, ax = plt.subplots(1, 1, gridspec_kw={\"wspace\": 0.01, \"hspace\": 0.01}, squeeze=True)\n", |
132 | 129 | "\n",
|
133 | 130 | "chip = chips[0]\n",
|
134 | 131 | "# Visualize the data\n",
|
|
160 | 157 | " Parameters:\n",
|
161 | 158 | " stack (xarray.DataArray): The input data array containing band information.\n",
|
162 | 159 | " \"\"\"\n",
|
163 |
| - " stack.sel(band=[1, 2, 3]).plot.imshow(\n", |
164 |
| - " rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", |
165 |
| - " )\n", |
| 160 | + " stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n", |
166 | 161 | " plt.show()\n",
|
167 |
| - " \n", |
| 162 | + "\n", |
| 163 | + "\n", |
168 | 164 | "def normalize_latlon(lat, lon):\n",
|
169 | 165 | " \"\"\"\n",
|
170 | 166 | " Normalize latitude and longitude to a range between -1 and 1.\n",
|
|
181 | 177 | "\n",
|
182 | 178 | " return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n",
|
183 | 179 | "\n",
|
| 180 | + "\n", |
184 | 181 | "def load_model(ckpt, device=\"cuda\"):\n",
|
185 | 182 | " \"\"\"\n",
|
186 | 183 | " Load a pretrained Clay model from a checkpoint.\n",
|
|
194 | 191 | " \"\"\"\n",
|
195 | 192 | " torch.set_default_device(device)\n",
|
196 | 193 | " model = CLAYModule.load_from_checkpoint(\n",
|
197 |
| - " ckpt, metadata_path=\"../configs/metadata.yaml\", shuffle=False, mask_ratio=0, model_size=\"medium\"\n", |
| 194 | + " ckpt,\n", |
| 195 | + " metadata_path=\"../configs/metadata.yaml\",\n", |
| 196 | + " shuffle=False,\n", |
| 197 | + " mask_ratio=0,\n", |
| 198 | + " model_size=\"medium\",\n", |
198 | 199 | " )\n",
|
199 | 200 | " model.eval()\n",
|
200 | 201 | " return model.to(device)\n",
|
201 | 202 | "\n",
|
| 203 | + "\n", |
202 | 204 | "def prep_datacube(stack, lat, lon, device):\n",
|
203 | 205 | " \"\"\"\n",
|
204 | 206 | " Prepare a data cube for model input.\n",
|
|
260 | 262 | " \"waves\": torch.tensor(waves, device=device),\n",
|
261 | 263 | " }\n",
|
262 | 264 | "\n",
|
| 265 | + "\n", |
263 | 266 | "def generate_embeddings(model, datacube):\n",
|
264 | 267 | " \"\"\"\n",
|
265 | 268 | " Generate embeddings from the model using the data cube.\n",
|
|
275 | 278 | " unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
|
276 | 279 | "\n",
|
277 | 280 | " # The first embedding is the class token, which is the\n",
|
278 |
| - " # overall single embedding. \n", |
279 |
| - " return unmsk_patch[:, 0, :].cpu().numpy()\n" |
| 281 | + " # overall single embedding.\n", |
| 282 | + " return unmsk_patch[:, 0, :].cpu().numpy()" |
280 | 283 | ]
|
281 | 284 | },
|
282 | 285 | {
|
|
337 | 340 | " # Calculate the centroid\n",
|
338 | 341 | " centroid_x = (tile.x * tile).sum() / tile.sum()\n",
|
339 | 342 | " centroid_y = (tile.y * tile).sum() / tile.sum()\n",
|
340 |
| - " \n", |
| 343 | + "\n", |
341 | 344 | " lon = centroid_x.item()\n",
|
342 | 345 | " lat = centroid_y.item()\n",
|
343 |
| - " \n", |
| 346 | + "\n", |
344 | 347 | " datacube = prep_datacube(tile, lat, lon, model.device)\n",
|
345 | 348 | " embeddings_ = generate_embeddings(model, datacube)\n",
|
346 | 349 | " embeddings.append(embeddings_)\n",
|
|
359 | 362 | " box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
|
360 | 363 | "\n",
|
361 | 364 | " # Create the GeoDataFrame\n",
|
362 |
| - " gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n", |
| 365 | + " gdf = gpd.GeoDataFrame(\n", |
| 366 | + " data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n", |
| 367 | + " )\n", |
363 | 368 | "\n",
|
364 | 369 | " # Reproject to WGS84 (lon/lat coordinates)\n",
|
365 | 370 | " gdf = gdf.to_crs(epsg=4326)\n",
|
366 | 371 | "\n",
|
367 |
| - " outpath = (\n", |
368 |
| - " f\"{outdir_embeddings}/\"\n", |
369 |
| - " f\"{fname[:-4]}.gpq\"\n", |
370 |
| - " )\n", |
| 372 | + " outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n", |
371 | 373 | " gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
|
372 | 374 | " print(\n",
|
373 | 375 | " f\"Saved {len(gdf)} rows of embeddings of \"\n",
|
|
438 | 440 | "data = []\n",
|
439 | 441 | "# Dataframe to find overlaps within\n",
|
440 | 442 | "gdfs = []\n",
|
441 |
| - "idx = 0\n", |
| 443 | + "idx = 0\n", |
442 | 444 | "for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
|
443 | 445 | " gdf = gpd.read_parquet(emb)\n",
|
444 | 446 | " gdf[\"year\"] = gdf.date.dt.year\n",
|
|
459 | 461 | " \"box\": row[\"box\"].bounds,\n",
|
460 | 462 | " }\n",
|
461 | 463 | " )\n",
|
462 |
| - " idx += 1\n" |
| 464 | + " idx += 1" |
463 | 465 | ]
|
464 | 466 | },
|
465 | 467 | {
|
|
560 | 562 | " for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
|
561 | 563 | " row = df.iloc[i]\n",
|
562 | 564 | " path = row[\"path\"]\n",
|
563 |
| - " chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue'])\n", |
564 |
| - " chip = chip.squeeze().transpose('x', 'y', 'band')\n", |
| 565 | + " chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n", |
| 566 | + " band=[\"red\", \"green\", \"blue\"]\n", |
| 567 | + " )\n", |
| 568 | + " chip = chip.squeeze().transpose(\"x\", \"y\", \"band\")\n", |
565 | 569 | " ax.imshow(chip)\n",
|
566 | 570 | " ax.set_title(f\"{row['idx']}\")\n",
|
567 | 571 | " ax.set_axis_off()\n",
|
|
0 commit comments