Skip to content

Commit 93f4174

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3d0ea79 commit 93f4174

File tree

1 file changed

+59
-65
lines changed

1 file changed

+59
-65
lines changed

nbs/240508-inference-naip.ipynb

+59-65
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"outputs": [],
99
"source": [
1010
"import sys\n",
11+
"\n",
1112
"sys.path.append(\"..\")"
1213
]
1314
},
@@ -18,34 +19,27 @@
1819
"metadata": {},
1920
"outputs": [],
2021
"source": [
21-
"import os\n",
2222
"import glob\n",
2323
"import math\n",
24-
"import boto3\n",
25-
"import yaml\n",
24+
"import os\n",
2625
"import random\n",
27-
"import numpy as np\n",
28-
"import pandas as pd\n",
26+
"\n",
2927
"import geopandas as gpd\n",
30-
"from shapely import Point\n",
31-
"from sklearn import decomposition\n",
28+
"import lancedb\n",
3229
"import matplotlib.pyplot as plt\n",
33-
"import xarray as xr\n",
30+
"import numpy as np\n",
31+
"import pandas as pd\n",
3432
"import rioxarray # noqa: F401\n",
33+
"import shapely # .geometry import Point, Polygon, box\n",
3534
"import torch\n",
36-
"import stackstac\n",
37-
"from pystac_client import Client\n",
38-
"import pystac_client\n",
35+
"import xarray as xr\n",
36+
"import yaml\n",
3937
"from box import Box\n",
40-
"import lancedb\n",
41-
"from pathlib import Path\n",
42-
"import shapely #.geometry import Point, Polygon, box\n",
43-
"from einops import rearrange\n",
44-
"from torchvision.transforms import v2\n",
38+
"from pystac_client import Client\n",
4539
"from stacchip.processors.prechip import normalize_timestamp\n",
40+
"from torchvision.transforms import v2\n",
4641
"\n",
47-
"from src.datamodule import ClayDataModule\n",
48-
"from src.model_clay_v1 import ClayMAEModule\n"
42+
"from src.model_clay_v1 import ClayMAEModule"
4943
]
5044
},
5145
{
@@ -56,11 +50,10 @@
5650
"outputs": [],
5751
"source": [
5852
"def plot_rgb(stack):\n",
59-
" stack.sel(band=[1, 2, 3]).plot.imshow(\n",
60-
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
61-
" )\n",
53+
" stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
6254
" plt.show()\n",
63-
" \n",
55+
"\n",
56+
"\n",
6457
"def normalize_latlon(lat, lon):\n",
6558
" lat = lat * np.pi / 180\n",
6659
" lon = lon * np.pi / 180\n",
@@ -130,13 +123,13 @@
130123
"\n",
131124
"\n",
132125
"def generate_embeddings(model, datacube):\n",
133-
" #print(datacube)\n",
126+
" # print(datacube)\n",
134127
" with torch.no_grad():\n",
135128
" unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
136129
"\n",
137130
" # The first embedding is the class token, which is the\n",
138131
" # overall single embedding. We extract that for PCA below.\n",
139-
" return unmsk_patch[:, 0, :].cpu().numpy()\n"
132+
" return unmsk_patch[:, 0, :].cpu().numpy()"
140133
]
141134
},
142135
{
@@ -172,11 +165,10 @@
172165
"for item in items.get_all_items():\n",
173166
" assets = item.assets\n",
174167
" dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
175-
" granule_name = item.assets[\"image\"].href.split('/')[-1]\n",
168+
" granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n",
176169
" stackstac_datasets.append(dataset)\n",
177170
" granule_names.append(granule_name)\n",
178-
" \n",
179-
" \n",
171+
"\n",
180172
"\n",
181173
"# Function to tile dataset into 256x256 image chips and drop any excess border regions\n",
182174
"def tile_dataset(dataset, granule_name):\n",
@@ -202,72 +194,73 @@
202194
" y_end = y_start + 256\n",
203195
"\n",
204196
" # Extract the tile from the cropped dataset\n",
205-
" tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n",
206-
" \n",
197+
" tile = cropped_dataset.isel(\n",
198+
" x=slice(x_start, x_end), y=slice(y_start, y_end)\n",
199+
" )\n",
200+
"\n",
207201
" # Calculate the centroid\n",
208202
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
209203
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
210-
" \n",
204+
"\n",
211205
" # Print or use the centroid coordinates\n",
212-
" #print(\"Centroid X:\", centroid_x.item())\n",
213-
" #print(\"Centroid Y:\", centroid_y.item())\n",
214-
" \n",
206+
" # print(\"Centroid X:\", centroid_x.item())\n",
207+
" # print(\"Centroid Y:\", centroid_y.item())\n",
208+
"\n",
215209
" lon = centroid_x.item()\n",
216210
" lat = centroid_y.item()\n",
217211
"\n",
218-
" tile = tile.assign_coords(band=['red','green','blue','nir'])\n",
212+
" tile = tile.assign_coords(band=[\"red\", \"green\", \"blue\", \"nir\"])\n",
219213
" tile_save = tile\n",
220214
"\n",
221-
" time_coord = xr.DataArray(['2020-01-01'], dims='time', name='time')\n",
215+
" time_coord = xr.DataArray([\"2020-01-01\"], dims=\"time\", name=\"time\")\n",
222216
"\n",
223217
" # Assign the time coordinate to the DataArray\n",
224218
" tile = tile.expand_dims(time=[0])\n",
225219
" tile = tile.assign_coords(time=time_coord)\n",
226220
"\n",
227-
" gsd_coord = xr.DataArray([0.6], dims='gsd', name='gsd')\n",
221+
" gsd_coord = xr.DataArray([0.6], dims=\"gsd\", name=\"gsd\")\n",
228222
"\n",
229223
" # Assign the gsd coordinate to the DataArray\n",
230224
" tile = tile.expand_dims(gsd=[0])\n",
231225
" tile = tile.assign_coords(gsd=gsd_coord)\n",
232226
"\n",
233227
" tile_name = f\"{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
234-
" #name_coord = xr.DataArray(tile_name, dims='filename', name='filename')\n",
228+
" # name_coord = xr.DataArray(tile_name, dims='filename', name='filename')\n",
235229
"\n",
236230
" # Assign the filename coordinate to the DataArray\n",
237-
" #tile = tile.expand_dims(filename=[0])\n",
238-
" #tile = tile.assign_coords(filename=name_coord)\n",
231+
" # tile = tile.expand_dims(filename=[0])\n",
232+
" # tile = tile.assign_coords(filename=name_coord)\n",
239233
"\n",
240-
" #print(tile)\n",
234+
" # print(tile)\n",
241235
"\n",
242236
" # Save the tile as a GeoTIFF\n",
243237
" tile_path = f\"{save_dir}/{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
244238
" tile_save.rio.to_raster(tile_path)\n",
245239
" tiles.append(tile)\n",
246240
" tile_names.append(tile_name)\n",
247-
" \n",
241+
"\n",
248242
" return tiles, tile_names\n",
249-
" \n",
243+
"\n",
250244
"\n",
251245
"make_tiles = False\n",
252246
"\n",
253247
"if make_tiles:\n",
254248
" tiles_ = []\n",
255249
" tile_names_ = []\n",
256-
" \n",
257-
" \n",
250+
"\n",
258251
" # Tile each dataset\n",
259252
" for dataset, granule_name in zip(stackstac_datasets, granule_names):\n",
260253
" tiles, tile_names = tile_dataset(dataset, granule_name)\n",
261254
" tiles_.append(tiles)\n",
262255
" tile_names_.append(tile_names)\n",
263-
" #tiles, tile_names = tile_dataset(stackstac_datasets[0], granule_names[0])\n",
256+
" # tiles, tile_names = tile_dataset(stackstac_datasets[0], granule_names[0])\n",
264257
" tiles__ = [tile for tile in tiles for tile_ in tiles_]\n",
265258
" tile_names__ = [tile for tile in tile_names for tile_ in tile_names_]\n",
266259
"else:\n",
267260
" tiles__ = []\n",
268261
" tile_names__ = []\n",
269262
" for filename in os.listdir(save_dir):\n",
270-
" if filename.endswith(\".tif\"): \n",
263+
" if filename.endswith(\".tif\"):\n",
271264
" tile_names__.append(filename)\n",
272265
" file_path = os.path.join(save_dir, filename)\n",
273266
" data_array = rioxarray.open_rasterio(file_path)\n",
@@ -317,7 +310,7 @@
317310
"source": [
318311
"model = load_model(\n",
319312
" # ckpt=\"s3://clay-model-ckpt/v0.5.3/mae_v0.5.3_epoch-29_val-loss-0.3073.ckpt\",\n",
320-
" #ckpt=\"../checkpoints/v0.5.3/mae_v0.5.3_epoch-08_val-loss-0.3150.ckpt\",\n",
313+
" # ckpt=\"../checkpoints/v0.5.3/mae_v0.5.3_epoch-08_val-loss-0.3150.ckpt\",\n",
321314
" ckpt=\"s3://clay-model-ckpt/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\",\n",
322315
" device=\"cuda\",\n",
323316
")\n",
@@ -327,14 +320,14 @@
327320
" # Calculate the centroid\n",
328321
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
329322
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
330-
" \n",
323+
"\n",
331324
" # Print or use the centroid coordinates\n",
332-
" #print(\"Centroid X:\", centroid_x.item())\n",
333-
" #print(\"Centroid Y:\", centroid_y.item())\n",
334-
" \n",
325+
" # print(\"Centroid X:\", centroid_x.item())\n",
326+
" # print(\"Centroid Y:\", centroid_y.item())\n",
327+
"\n",
335328
" lon = centroid_x.item()\n",
336329
" lat = centroid_y.item()\n",
337-
" \n",
330+
"\n",
338331
" datacube = prep_datacube(tile, lat, lon, model.device)\n",
339332
" embeddings_ = generate_embeddings(model, datacube)\n",
340333
" embeddings.append(embeddings_)\n",
@@ -355,21 +348,20 @@
355348
" box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
356349
"\n",
357350
" # Create the GeoDataFrame\n",
358-
" gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n",
351+
" gdf = gpd.GeoDataFrame(\n",
352+
" data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n",
353+
" )\n",
359354
"\n",
360355
" # Reproject to WGS84 (lon/lat coordinates)\n",
361356
" gdf = gdf.to_crs(epsg=4326)\n",
362357
"\n",
363-
" outpath = (\n",
364-
" f\"{outdir_embeddings}/\"\n",
365-
" f\"{fname[:-4]}.gpq\"\n",
366-
" )\n",
358+
" outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n",
367359
" gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
368360
" print(\n",
369361
" f\"Saved {len(gdf)} rows of embeddings of \"\n",
370362
" f\"shape {gdf.embeddings.iloc[0].shape} to {outpath}\"\n",
371363
" )\n",
372-
" i=i+1"
364+
" i = i + 1"
373365
]
374366
},
375367
{
@@ -435,7 +427,7 @@
435427
"data = []\n",
436428
"# Dataframe to find overlaps within\n",
437429
"gdfs = []\n",
438-
"idx = 0\n",
430+
"idx = 0\n",
439431
"for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
440432
" gdf = gpd.read_parquet(emb)\n",
441433
" gdf[\"year\"] = gdf.date.dt.year\n",
@@ -456,7 +448,7 @@
456448
" \"box\": row[\"box\"].bounds,\n",
457449
" }\n",
458450
" )\n",
459-
" idx = idx+1"
451+
" idx = idx + 1"
460452
]
461453
},
462454
{
@@ -572,21 +564,23 @@
572564
"source": [
573565
"def plot(df, cols=10):\n",
574566
" fig, axs = plt.subplots(1, cols, figsize=(20, 10))\n",
575-
" i=0\n",
567+
" i = 0\n",
576568
" for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
577569
" row = df.iloc[i]\n",
578570
" path = row[\"path\"]\n",
579-
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue']) #[1,2,3])\n",
580-
" #chip = tiles__[row[\"idx\"]].sel(band=['red', 'green', 'blue'])\n",
571+
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n",
572+
" band=[\"red\", \"green\", \"blue\"]\n",
573+
" ) # [1,2,3])\n",
574+
" # chip = tiles__[row[\"idx\"]].sel(band=['red', 'green', 'blue'])\n",
581575
" width = chip.shape[-1]\n",
582576
" height = chip.shape[-1]\n",
583577
" chip = chip.squeeze()\n",
584-
" chip = chip.transpose('x', 'y', 'band')\n",
578+
" chip = chip.transpose(\"x\", \"y\", \"band\")\n",
585579
"\n",
586580
" ax.imshow(chip)\n",
587581
" ax.set_title(f\"{row['idx']}\")\n",
588582
" ax.set_axis_off()\n",
589-
" i=i+1\n",
583+
" i = i + 1\n",
590584
" plt.tight_layout()\n",
591585
" fig.savefig(\"similar.png\")"
592586
]

0 commit comments

Comments
 (0)