Skip to content

Commit 64e5479

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d0ec7d9 commit 64e5479

File tree

1 file changed

+43
-42
lines changed

1 file changed

+43
-42
lines changed

nbs/240508-inference-naip.ipynb

+43-42
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"outputs": [],
1818
"source": [
1919
"import sys\n",
20+
"\n",
2021
"sys.path.append(\"..\")"
2122
]
2223
},
@@ -27,32 +28,27 @@
2728
"metadata": {},
2829
"outputs": [],
2930
"source": [
30-
"import os\n",
3131
"import glob\n",
3232
"import math\n",
33-
"import boto3\n",
34-
"import yaml\n",
33+
"import os\n",
3534
"import random\n",
36-
"import numpy as np\n",
37-
"import pandas as pd\n",
35+
"\n",
3836
"import geopandas as gpd\n",
37+
"import lancedb\n",
3938
"import matplotlib.pyplot as plt\n",
40-
"import xarray as xr\n",
39+
"import numpy as np\n",
40+
"import pandas as pd\n",
4141
"import rioxarray # noqa: F401\n",
42+
"import shapely\n",
4243
"import torch\n",
43-
"import stackstac\n",
44-
"from pystac_client import Client\n",
45-
"import pystac_client\n",
44+
"import xarray as xr\n",
45+
"import yaml\n",
4646
"from box import Box\n",
47-
"import lancedb\n",
48-
"from pathlib import Path\n",
49-
"import shapely\n",
50-
"from einops import rearrange\n",
51-
"from torchvision.transforms import v2\n",
47+
"from pystac_client import Client\n",
5248
"from stacchip.processors.prechip import normalize_timestamp\n",
49+
"from torchvision.transforms import v2\n",
5350
"\n",
54-
"from src.datamodule import ClayDataModule\n",
55-
"from src.model_clay_v1 import ClayMAEModule\n"
51+
"from src.model_clay_v1 import ClayMAEModule"
5652
]
5753
},
5854
{
@@ -69,11 +65,10 @@
6965
" Parameters:\n",
7066
" stack (xarray.DataArray): The input data array containing band information.\n",
7167
" \"\"\"\n",
72-
" stack.sel(band=[1, 2, 3]).plot.imshow(\n",
73-
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
74-
" )\n",
68+
" stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
7569
" plt.show()\n",
76-
" \n",
70+
"\n",
71+
"\n",
7772
"def normalize_latlon(lat, lon):\n",
7873
" \"\"\"\n",
7974
" Normalize latitude and longitude to a range between -1 and 1.\n",
@@ -90,6 +85,7 @@
9085
"\n",
9186
" return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n",
9287
"\n",
88+
"\n",
9389
"def load_model(ckpt, device=\"cuda\"):\n",
9490
" \"\"\"\n",
9591
" Load a pretrained Clay model from a checkpoint.\n",
@@ -108,6 +104,7 @@
108104
" model.eval()\n",
109105
" return model.to(device)\n",
110106
"\n",
107+
"\n",
111108
"def prep_datacube(stack, lat, lon, device):\n",
112109
" \"\"\"\n",
113110
" Prepare a data cube for model input.\n",
@@ -169,6 +166,7 @@
169166
" \"waves\": torch.tensor(waves, device=device),\n",
170167
" }\n",
171168
"\n",
169+
"\n",
172170
"def generate_embeddings(model, datacube):\n",
173171
" \"\"\"\n",
174172
" Generate embeddings from the model using the data cube.\n",
@@ -184,8 +182,8 @@
184182
" unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
185183
"\n",
186184
" # The first embedding is the class token, which is the\n",
187-
" # overall single embedding. \n",
188-
" return unmsk_patch[:, 0, :].cpu().numpy()\n"
185+
" # overall single embedding.\n",
186+
" return unmsk_patch[:, 0, :].cpu().numpy()"
189187
]
190188
},
191189
{
@@ -227,23 +225,25 @@
227225
" y_end = y_start + 256\n",
228226
"\n",
229227
" # Extract the tile from the cropped dataset\n",
230-
" tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n",
231-
" \n",
228+
" tile = cropped_dataset.isel(\n",
229+
" x=slice(x_start, x_end), y=slice(y_start, y_end)\n",
230+
" )\n",
231+
"\n",
232232
" # Calculate the centroid\n",
233233
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
234234
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
235-
" \n",
235+
"\n",
236236
" lon = centroid_x.item()\n",
237237
" lat = centroid_y.item()\n",
238238
"\n",
239-
" tile = tile.assign_coords(band=['red','green','blue','nir'])\n",
239+
" tile = tile.assign_coords(band=[\"red\", \"green\", \"blue\", \"nir\"])\n",
240240
" tile_save = tile\n",
241241
"\n",
242-
" time_coord = xr.DataArray(['2020-01-01'], dims='time', name='time')\n",
242+
" time_coord = xr.DataArray([\"2020-01-01\"], dims=\"time\", name=\"time\")\n",
243243
" tile = tile.expand_dims(time=[0])\n",
244244
" tile = tile.assign_coords(time=time_coord)\n",
245245
"\n",
246-
" gsd_coord = xr.DataArray([0.6], dims='gsd', name='gsd')\n",
246+
" gsd_coord = xr.DataArray([0.6], dims=\"gsd\", name=\"gsd\")\n",
247247
" tile = tile.expand_dims(gsd=[0])\n",
248248
" tile = tile.assign_coords(gsd=gsd_coord)\n",
249249
"\n",
@@ -253,7 +253,7 @@
253253
" tile_save.rio.to_raster(tile_path)\n",
254254
" tiles.append(tile)\n",
255255
" tile_names.append(tile_name)\n",
256-
" \n",
256+
"\n",
257257
" return tiles, tile_names"
258258
]
259259
},
@@ -291,7 +291,7 @@
291291
"for item in items.get_all_items():\n",
292292
" assets = item.assets\n",
293293
" dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
294-
" granule_name = item.assets[\"image\"].href.split('/')[-1]\n",
294+
" granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n",
295295
" stackstac_datasets.append(dataset)\n",
296296
" granule_names.append(granule_name)"
297297
]
@@ -321,7 +321,7 @@
321321
" tiles__ = []\n",
322322
" tile_names__ = []\n",
323323
" for filename in os.listdir(save_dir):\n",
324-
" if filename.endswith(\".tif\"): \n",
324+
" if filename.endswith(\".tif\"):\n",
325325
" tile_names__.append(filename)\n",
326326
" file_path = os.path.join(save_dir, filename)\n",
327327
" data_array = rioxarray.open_rasterio(file_path)\n",
@@ -391,10 +391,10 @@
391391
" # Calculate the centroid\n",
392392
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
393393
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
394-
" \n",
394+
"\n",
395395
" lon = centroid_x.item()\n",
396396
" lat = centroid_y.item()\n",
397-
" \n",
397+
"\n",
398398
" datacube = prep_datacube(tile, lat, lon, model.device)\n",
399399
" embeddings_ = generate_embeddings(model, datacube)\n",
400400
" embeddings.append(embeddings_)\n",
@@ -413,15 +413,14 @@
413413
" box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
414414
"\n",
415415
" # Create the GeoDataFrame\n",
416-
" gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n",
416+
" gdf = gpd.GeoDataFrame(\n",
417+
" data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n",
418+
" )\n",
417419
"\n",
418420
" # Reproject to WGS84 (lon/lat coordinates)\n",
419421
" gdf = gdf.to_crs(epsg=4326)\n",
420422
"\n",
421-
" outpath = (\n",
422-
" f\"{outdir_embeddings}/\"\n",
423-
" f\"{fname[:-4]}.gpq\"\n",
424-
" )\n",
423+
" outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n",
425424
" gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
426425
" print(\n",
427426
" f\"Saved {len(gdf)} rows of embeddings of \"\n",
@@ -494,7 +493,7 @@
494493
"data = []\n",
495494
"# Dataframe to find overlaps within\n",
496495
"gdfs = []\n",
497-
"idx = 0\n",
496+
"idx = 0\n",
498497
"for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
499498
" gdf = gpd.read_parquet(emb)\n",
500499
" gdf[\"year\"] = gdf.date.dt.year\n",
@@ -515,7 +514,7 @@
515514
" \"box\": row[\"box\"].bounds,\n",
516515
" }\n",
517516
" )\n",
518-
" idx += 1\n"
517+
" idx += 1"
519518
]
520519
},
521520
{
@@ -637,8 +636,10 @@
637636
" for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
638637
" row = df.iloc[i]\n",
639638
" path = row[\"path\"]\n",
640-
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue'])\n",
641-
" chip = chip.squeeze().transpose('x', 'y', 'band')\n",
639+
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n",
640+
" band=[\"red\", \"green\", \"blue\"]\n",
641+
" )\n",
642+
" chip = chip.squeeze().transpose(\"x\", \"y\", \"band\")\n",
642643
" ax.imshow(chip)\n",
643644
" ax.set_title(f\"{row['idx']}\")\n",
644645
" ax.set_axis_off()\n",

0 commit comments

Comments
 (0)