Skip to content

Commit 82a68fb

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0c72b7e commit 82a68fb

File tree

2 files changed

+87
-81
lines changed

2 files changed

+87
-81
lines changed

nbs/v1-inference-simsearch-naip-stacchip.ipynb

+43 -39
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
"outputs": [],
1818
"source": [
1919
"import sys\n",
20+
"\n",
2021
"sys.path.append(\"..\")"
2122
]
2223
},
@@ -27,34 +28,28 @@
2728
"metadata": {},
2829
"outputs": [],
2930
"source": [
30-
"import os\n",
3131
"import glob\n",
3232
"import math\n",
33-
"import boto3\n",
34-
"import yaml\n",
33+
"import os\n",
3534
"import random\n",
36-
"import numpy as np\n",
37-
"import pandas as pd\n",
35+
"\n",
3836
"import geopandas as gpd\n",
37+
"import lancedb\n",
3938
"import matplotlib.pyplot as plt\n",
40-
"import xarray as xr\n",
39+
"import numpy as np\n",
40+
"import pandas as pd\n",
41+
"import pystac_client\n",
4142
"import rioxarray # noqa: F401\n",
43+
"import shapely\n",
4244
"import torch\n",
43-
"import stackstac\n",
44-
"from pystac_client import Client\n",
45-
"import pystac_client\n",
45+
"import yaml\n",
4646
"from box import Box\n",
47-
"import lancedb\n",
48-
"from pathlib import Path\n",
49-
"import shapely\n",
50-
"from einops import rearrange\n",
51-
"from torchvision.transforms import v2\n",
52-
"from stacchip.processors.prechip import normalize_timestamp\n",
53-
"from stacchip.indexer import NoStatsChipIndexer\n",
5447
"from stacchip.chipper import Chipper\n",
48+
"from stacchip.indexer import NoStatsChipIndexer\n",
49+
"from stacchip.processors.prechip import normalize_timestamp\n",
50+
"from torchvision.transforms import v2\n",
5551
"\n",
56-
"from src.datamodule import ClayDataModule\n",
57-
"from src.model_clay import CLAYModule\n"
52+
"from src.model_clay import CLAYModule"
5853
]
5954
},
6055
{
@@ -65,7 +60,9 @@
6560
"outputs": [],
6661
"source": [
6762
"# Query STAC catalog for NAIP data\n",
68-
"catalog = pystac_client.Client.open(\"https://planetarycomputer.microsoft.com/api/stac/v1\") #\"https://earth-search.aws.element84.com/v1\")\n",
63+
"catalog = pystac_client.Client.open(\n",
64+
" \"https://planetarycomputer.microsoft.com/api/stac/v1\"\n",
65+
") # \"https://earth-search.aws.element84.com/v1\")\n",
6966
"\n",
7067
"\n",
7168
"items = catalog.search(\n",
@@ -99,7 +96,7 @@
9996
"\n",
10097
" # Get first chip for the \"image\" asset key\n",
10198
" for chip_id in random.sample(range(0, len(chipper)), 5):\n",
102-
" chips.append(chipper[chip_id][\"image\"])\n"
99+
" chips.append(chipper[chip_id][\"image\"])"
103100
]
104101
},
105102
{
@@ -128,7 +125,7 @@
128125
}
129126
],
130127
"source": [
131-
"fig, ax = plt.subplots(1, 1, gridspec_kw={'wspace': 0.01, 'hspace': 0.01}, squeeze=True)\n",
128+
"fig, ax = plt.subplots(1, 1, gridspec_kw={\"wspace\": 0.01, \"hspace\": 0.01}, squeeze=True)\n",
132129
"\n",
133130
"chip = chips[0]\n",
134131
"# Visualize the data\n",
@@ -160,11 +157,10 @@
160157
" Parameters:\n",
161158
" stack (xarray.DataArray): The input data array containing band information.\n",
162159
" \"\"\"\n",
163-
" stack.sel(band=[1, 2, 3]).plot.imshow(\n",
164-
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
165-
" )\n",
160+
" stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
166161
" plt.show()\n",
167-
" \n",
162+
"\n",
163+
"\n",
168164
"def normalize_latlon(lat, lon):\n",
169165
" \"\"\"\n",
170166
" Normalize latitude and longitude to a range between -1 and 1.\n",
@@ -181,6 +177,7 @@
181177
"\n",
182178
" return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n",
183179
"\n",
180+
"\n",
184181
"def load_model(ckpt, device=\"cuda\"):\n",
185182
" \"\"\"\n",
186183
" Load a pretrained Clay model from a checkpoint.\n",
@@ -194,11 +191,16 @@
194191
" \"\"\"\n",
195192
" torch.set_default_device(device)\n",
196193
" model = CLAYModule.load_from_checkpoint(\n",
197-
" ckpt, metadata_path=\"../configs/metadata.yaml\", shuffle=False, mask_ratio=0, model_size=\"medium\"\n",
194+
" ckpt,\n",
195+
" metadata_path=\"../configs/metadata.yaml\",\n",
196+
" shuffle=False,\n",
197+
" mask_ratio=0,\n",
198+
" model_size=\"medium\",\n",
198199
" )\n",
199200
" model.eval()\n",
200201
" return model.to(device)\n",
201202
"\n",
203+
"\n",
202204
"def prep_datacube(stack, lat, lon, device):\n",
203205
" \"\"\"\n",
204206
" Prepare a data cube for model input.\n",
@@ -260,6 +262,7 @@
260262
" \"waves\": torch.tensor(waves, device=device),\n",
261263
" }\n",
262264
"\n",
265+
"\n",
263266
"def generate_embeddings(model, datacube):\n",
264267
" \"\"\"\n",
265268
" Generate embeddings from the model using the data cube.\n",
@@ -275,8 +278,8 @@
275278
" unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
276279
"\n",
277280
" # The first embedding is the class token, which is the\n",
278-
" # overall single embedding. \n",
279-
" return unmsk_patch[:, 0, :].cpu().numpy()\n"
281+
" # overall single embedding.\n",
282+
" return unmsk_patch[:, 0, :].cpu().numpy()"
280283
]
281284
},
282285
{
@@ -337,10 +340,10 @@
337340
" # Calculate the centroid\n",
338341
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
339342
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
340-
" \n",
343+
"\n",
341344
" lon = centroid_x.item()\n",
342345
" lat = centroid_y.item()\n",
343-
" \n",
346+
"\n",
344347
" datacube = prep_datacube(tile, lat, lon, model.device)\n",
345348
" embeddings_ = generate_embeddings(model, datacube)\n",
346349
" embeddings.append(embeddings_)\n",
@@ -359,15 +362,14 @@
359362
" box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
360363
"\n",
361364
" # Create the GeoDataFrame\n",
362-
" gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n",
365+
" gdf = gpd.GeoDataFrame(\n",
366+
" data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n",
367+
" )\n",
363368
"\n",
364369
" # Reproject to WGS84 (lon/lat coordinates)\n",
365370
" gdf = gdf.to_crs(epsg=4326)\n",
366371
"\n",
367-
" outpath = (\n",
368-
" f\"{outdir_embeddings}/\"\n",
369-
" f\"{fname[:-4]}.gpq\"\n",
370-
" )\n",
372+
" outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n",
371373
" gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
372374
" print(\n",
373375
" f\"Saved {len(gdf)} rows of embeddings of \"\n",
@@ -438,7 +440,7 @@
438440
"data = []\n",
439441
"# Dataframe to find overlaps within\n",
440442
"gdfs = []\n",
441-
"idx = 0\n",
443+
"idx = 0\n",
442444
"for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
443445
" gdf = gpd.read_parquet(emb)\n",
444446
" gdf[\"year\"] = gdf.date.dt.year\n",
@@ -459,7 +461,7 @@
459461
" \"box\": row[\"box\"].bounds,\n",
460462
" }\n",
461463
" )\n",
462-
" idx += 1\n"
464+
" idx += 1"
463465
]
464466
},
465467
{
@@ -560,8 +562,10 @@
560562
" for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
561563
" row = df.iloc[i]\n",
562564
" path = row[\"path\"]\n",
563-
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue'])\n",
564-
" chip = chip.squeeze().transpose('x', 'y', 'band')\n",
565+
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n",
566+
" band=[\"red\", \"green\", \"blue\"]\n",
567+
" )\n",
568+
" chip = chip.squeeze().transpose(\"x\", \"y\", \"band\")\n",
565569
" ax.imshow(chip)\n",
566570
" ax.set_title(f\"{row['idx']}\")\n",
567571
" ax.set_axis_off()\n",

0 commit comments

Comments
 (0)