diff --git a/nbs/v1-inference-simsearch-naip-stacchip.ipynb b/nbs/v1-inference-simsearch-naip-stacchip.ipynb new file mode 100644 index 00000000..8bc858d0 --- /dev/null +++ b/nbs/v1-inference-simsearch-naip-stacchip.ipynb @@ -0,0 +1,664 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a802ca09-12d2-46a6-b530-70c8be0448e6", + "metadata": {}, + "source": [ + "# NAIP Inference and Similarity Search with Clay v1\n", + "\n", + "This notebook walks through Clay model v1 inference on [NAIP (National Agriculture Imagery Program) data](https://naip-usdaonline.hub.arcgis.com/) and similarity search. The workflow includes loading and preprocessing data from STAC, tiling the images and encoding metadata, generating embeddings and querying across them for similar representations. The NAIP data comes in annual composites. We are using data from one year within a sampled region in San Francisco, California.\n", + "\n", + "The workflow includes the following steps:\n", + "\n", + "1. **Loading and Preprocessing Data**:\n", + " - Connect to a STAC (SpatioTemporal Asset Catalog) to query and download NAIP imagery for a specified region and time period.\n", + " - Preprocess the downloaded imagery, including tiling the images and extracting metadata.\n", + "\n", + "2. **Generating Embeddings**:\n", + " - Use a pretrained Clay model to generate embeddings for the preprocessed image tiles.\n", + "\n", + "3. **Saving Embeddings**:\n", + " - Save the generated embeddings along with the associated image data and select metadata in parquet format.\n", + "\n", + "4. **Similarity Search**:\n", + " - Load the saved embeddings into a LanceDB database.\n", + " - Perform a similarity search by querying the database with a randomly selected embedding.\n", + " - Retrieve and display the top similar images based on the similarity search." 
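As a compact illustration of steps 3 and 4, the embed-and-search loop can be sketched as follows. This is a minimal sketch rather than the notebook's exact code: it assumes `embeddings` is a list of NumPy arrays of shape `(1, embedding_dim)` produced by the model, `metadata` is a matching list of per-chip dictionaries, and the table name `clay-v001` mirrors the one used later in this notebook.

```python
import lancedb

# Build one record per chip: the embedding vector plus any metadata to keep
records = [
    {"vector": emb.squeeze(), **meta} for emb, meta in zip(embeddings, metadata)
]

# LanceDB persists the table inside the given directory
db = lancedb.connect("../data/embeddings/")
tbl = db.create_table("clay-v001", data=records, mode="overwrite")

# Query with one of the stored vectors and return the closest matches
query_vector = records[0]["vector"]
result = tbl.search(query=query_vector).limit(6).to_pandas()
print(result)
```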
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0c141b9-4038-4542-832c-f71e04bd93c1", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eabd5bef", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install stacchip==0.1.33" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48002199-2fab-4d85-aeba-25f651865f98", + "metadata": {}, + "outputs": [], + "source": [ + "import datetime\n", + "import glob\n", + "import math\n", + "import os\n", + "import random\n", + "\n", + "import geopandas as gpd\n", + "import lancedb\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pystac_client\n", + "import shapely\n", + "import torch\n", + "import yaml\n", + "from box import Box\n", + "from stacchip.chipper import Chipper\n", + "from stacchip.indexer import NoStatsChipIndexer\n", + "from stacchip.processors.prechip import normalize_timestamp\n", + "from torchvision.transforms import v2\n", + "\n", + "from src.model import ClayMAEModule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce6a21b7-cfe8-44f7-86ce-ce7852f963eb", + "metadata": {}, + "outputs": [], + "source": [ + "# Define the platform name and year for the NAIP data\n", + "PLATFORM_NAME = \"naip\"\n", + "YEAR = 2020" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71742792-0bf4-429e-ab00-122643a9fafd", + "metadata": {}, + "outputs": [], + "source": [ + "# Query STAC catalog for NAIP data\n", + "catalog = pystac_client.Client.open(\n", + "    \"https://planetarycomputer.microsoft.com/api/stac/v1\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "279e0cdf-da24-411e-9a90-1b42327cc766", + "metadata": {}, + "outputs": [], + "source": [ + "# Perform a search on the STAC catalog,\n", + "# specifying the collection to search within (NAIP data),\n", + "# defining the bounding box for the search area (San Francisco region), and\n", + "# setting the date range for the search (entire year 2020).\n", + "# Also limit the search to a maximum of 100 items.\n", + "items = catalog.search(\n", + "    collections=[PLATFORM_NAME],\n", + "    bbox=[-122.6, 37.6, -122.35, 37.85],\n", + "    datetime=f\"{YEAR}-01-01T00:00:00Z/{YEAR+1}-01-01T00:00:00Z\",\n", + "    max_items=100,\n", + ")\n", + "\n", + "# Convert the search results to an item collection\n", + "items = items.item_collection()\n", + "\n", + "# Convert the item collection to a list for easier manipulation\n", + "items_list = list(items)\n", + "\n", + "# Randomly shuffle the list of items to ensure random sampling\n", + "random.shuffle(items_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccb089e6-36d7-41a4-8297-f5517aa42065", + "metadata": {}, + "outputs": [], + "source": [ + "chip_images = []  # List to hold chip pixels\n", + "chip_bounds = []  # List to hold chip bounds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d058e157-33f5-4276-aaec-6f68677bf3d2", + "metadata": {}, + "outputs": [], + "source": [ + "for item in items_list[:2]:\n", + "    print(f\"Working on {item}\")\n", + "\n", + "    # Index the chips in the item\n", + "    indexer = NoStatsChipIndexer(item)\n", + "\n", + "    # Instantiate the chipper\n", + "    chipper = Chipper(\n", + "        indexer, asset_blacklist=[\"thumbnail\", \"tilejson\", \"rendered_preview\"]\n", + "    )\n", + "\n", + "    # Get 25 randomly sampled
chips from the total\n", + " # number of chips within this item's entire image\n", + " for chip_id in random.sample(range(0, len(chipper)), 25):\n", + " x_index, y_index, chip = chipper[chip_id]\n", + " chip_images.append(chip[\"image\"])\n", + " chip_bounds.append(indexer.get_chip_bbox(x_index, y_index))" + ] + }, + { + "cell_type": "markdown", + "id": "b61aaad7", + "metadata": {}, + "source": [ + "Visualize a generated image chip." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fabb53d-0132-4fba-befb-d16be8116428", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, gridspec_kw={\"wspace\": 0.01, \"hspace\": 0.01}, squeeze=True)\n", + "\n", + "chip = chip_images[-1]\n", + "\n", + "# Visualize the data\n", + "ax.imshow(chip[:3].swapaxes(0, 1).swapaxes(1, 2))\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "589611d4", + "metadata": {}, + "source": [ + "Below are some functions we will rely on to prepare the data cubes, generate embeddings, and plot subsets of the chipped images for visualization purposes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a0112ed-7bbc-434f-8914-7160a3a2c239", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_rgb(stack):\n", + " \"\"\"\n", + " Plot the RGB bands of the given stack.\n", + "\n", + " Parameters:\n", + " stack (xarray.DataArray): The input data array containing band information.\n", + " \"\"\"\n", + " stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e23b0258-2b41-4c45-a0df-0d2e62809067", + "metadata": {}, + "outputs": [], + "source": [ + "def normalize_latlon(lat, lon):\n", + " \"\"\"\n", + " Normalize latitude and longitude to a range between -1 and 1.\n", + "\n", + " Parameters:\n", + " lat (float): Latitude value.\n", + " lon (float): Longitude value.\n", + "\n", + " Returns:\n", + " tuple: Normalized latitude and longitude values.\n", + " \"\"\"\n", + " lat = lat * np.pi / 180\n", + " lon = lon * np.pi / 180\n", + "\n", + " return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7fba068e-ee25-43c0-976f-c2beac2d39ae", + "metadata": {}, + "outputs": [], + "source": [ + "def load_model(ckpt, device=\"cuda\"):\n", + " \"\"\"\n", + " Load a pretrained Clay model from a checkpoint.\n", + "\n", + " Parameters:\n", + " ckpt (str): Path to the model checkpoint.\n", + " device (str): Device to load the model onto (default is 'cuda').\n", + "\n", + " Returns:\n", + " model: Loaded model.\n", + " \"\"\"\n", + " torch.set_default_device(device)\n", + " model = ClayMAEModule.load_from_checkpoint(\n", + " ckpt, metadata_path=\"../configs/metadata.yaml\", shuffle=False, mask_ratio=0\n", + " )\n", + " model.eval()\n", + " return model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "029e3fb5-c6eb-4b8b-9430-d3e2a6f0b3a2", + "metadata": {}, + "outputs": [], + "source": [ + "def prep_datacube(image, lat, lon, date, gsd, device):\n", + " \"\"\"\n", + " Prepare a data cube for model input.\n", + "\n", + " Parameters:\n", + " image (np.array): The input image array.\n", + " lat (float): Latitude value for the location.\n", + " lon (float): Longitude value for the location.\n", + " device (str): Device to load the data onto.\n", + "\n", + " Returns:\n", + " dict: Prepared data cube 
with normalized values and embeddings.\n", + " \"\"\"\n", + " platform = \"naip\"\n", + "\n", + " # Extract mean, std, and wavelengths from metadata\n", + " metadata = Box(yaml.safe_load(open(\"../configs/metadata.yaml\")))\n", + " mean = []\n", + " std = []\n", + " waves = []\n", + " bands = [\"red\", \"green\", \"blue\", \"nir\"]\n", + " for band_name in bands:\n", + " mean.append(metadata[platform].bands.mean[band_name])\n", + " std.append(metadata[platform].bands.std[band_name])\n", + " waves.append(metadata[platform].bands.wavelength[band_name])\n", + "\n", + " transform = v2.Compose(\n", + " [\n", + " v2.Normalize(mean=mean, std=std),\n", + " ]\n", + " )\n", + "\n", + " # Prep datetimes embedding\n", + " times = normalize_timestamp(date)\n", + " week_norm = times[0]\n", + " hour_norm = times[1]\n", + "\n", + " # Prep lat/lon embedding\n", + " latlons = normalize_latlon(lat, lon)\n", + " lat_norm = latlons[0]\n", + " lon_norm = latlons[1]\n", + "\n", + " # Prep pixels\n", + " pixels = torch.from_numpy(image.astype(np.float32))\n", + " pixels = transform(pixels)\n", + " pixels = pixels.unsqueeze(0)\n", + "\n", + " # Prepare additional information\n", + " return {\n", + " \"pixels\": pixels.to(device),\n", + " \"time\": torch.tensor(\n", + " np.hstack((week_norm, hour_norm)),\n", + " dtype=torch.float32,\n", + " device=device,\n", + " ).unsqueeze(0),\n", + " \"latlon\": torch.tensor(\n", + " np.hstack((lat_norm, lon_norm)), dtype=torch.float32, device=device\n", + " ).unsqueeze(0),\n", + " \"gsd\": torch.tensor(gsd, device=device),\n", + " \"waves\": torch.tensor(waves, device=device),\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bba433de-df3a-4e8e-880a-653352b9d4bb", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_embeddings(model, datacube):\n", + " \"\"\"\n", + " Generate embeddings from the model.\n", + "\n", + " Parameters:\n", + " model (ClayMAEModule): The pretrained model.\n", + " datacube (dict): Prepared data cube.\n", + "\n", + " Returns:\n", + " numpy.ndarray: Generated embeddings.\n", + " \"\"\"\n", + " with torch.no_grad():\n", + " unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n", + "\n", + " # The first embedding is the class token, which is the\n", + " # overall single embedding.\n", + " return unmsk_patch[:, 0, :].cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "419e0d80-7250-4397-b383-db027813536a", + "metadata": {}, + "outputs": [], + "source": [ + "outdir_embeddings = \"../data/embeddings/\"\n", + "os.makedirs(outdir_embeddings, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "d8f01975", + "metadata": {}, + "source": [ + "### Load the trained Clay v1 model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59d0bdbc", + "metadata": {}, + "outputs": [], + "source": [ + "# Download the pretrained model from\n", + "# https://huggingface.co/made-with-clay/Clay/blob/main/clay-v1-base.ckpt\n", + "# and put it in a checkpoints folder.\n", + "model = load_model(\n", + " ckpt=\"../checkpoints/clay-v1-base.ckpt\",\n", + " device=torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\"),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e5ae07f7", + "metadata": {}, + "source": [ + "### Generate embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05409114-2bba-4497-8cbb-a7303a8d5be6", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + 
"embeddings = []\n", + "i = 0\n", + "for tile, box in zip(chip_images, chip_bounds):\n", + " date = datetime.datetime.strptime(f\"{YEAR}-06-01\", \"%Y-%m-%d\")\n", + " gsd = 0.6\n", + "\n", + " lon, lat = chip_bounds[0].centroid.coords[0]\n", + "\n", + " datacube = prep_datacube(\n", + " np.array(tile), lat, lon, pd.to_datetime(f\"{YEAR}-06-01\"), gsd, model.device\n", + " )\n", + " embeddings_ = generate_embeddings(model, datacube)\n", + " embeddings.append(embeddings_)\n", + "\n", + " data = {\n", + " \"source_url\": str(i),\n", + " \"date\": pd.to_datetime(arg=date, format=\"%Y-%m-%d\"),\n", + " \"embeddings\": [np.ascontiguousarray(embeddings_.squeeze())],\n", + " \"image\": [np.ascontiguousarray(np.array(tile.transpose(1, 2, 0)).flatten())],\n", + " }\n", + "\n", + " # Create the GeoDataFrame\n", + " gdf = gpd.GeoDataFrame(data, geometry=[box], crs=\"EPSG:4326\")\n", + "\n", + " outpath = f\"{outdir_embeddings}/{i}.gpq\"\n", + " gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n", + " print(\n", + " f\"Saved {len(gdf)} rows of embeddings of \"\n", + " f\"shape {gdf.embeddings.iloc[0].shape} to {outpath}\"\n", + " )\n", + " i += 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ec8c89b-38fe-44e3-9780-e1df8a3f41c2", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Created {len(embeddings)} embeddings of shape {embeddings[0].shape[1]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a9cf665b", + "metadata": {}, + "source": [ + "### Run a similarity search to identify similar embeddings\n", + "We will select a random index to search with and plot the corresponding RGB images from the search results. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99132892-c6bd-4c01-a92c-f54e1cba1921", + "metadata": {}, + "outputs": [], + "source": [ + "# Connect to the embeddings database\n", + "db = lancedb.connect(outdir_embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edf872d0-6b90-4102-9dc3-c726527c19dc", + "metadata": {}, + "outputs": [], + "source": [ + "# Data for DB table\n", + "data = []\n", + "# Dataframe to find overlaps within\n", + "gdfs = []\n", + "idx = 0\n", + "for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n", + " gdf = gpd.read_parquet(emb)\n", + " gdf[\"year\"] = gdf.date.dt.year\n", + " gdf[\"tile\"] = gdf[\"source_url\"]\n", + " gdf[\"idx\"] = idx\n", + " gdf[\"box\"] = [shapely.geometry.box(*geom.bounds) for geom in gdf.geometry]\n", + " gdfs.append(gdf)\n", + "\n", + " for _, row in gdf.iterrows():\n", + " data.append(\n", + " {\n", + " \"vector\": row[\"embeddings\"],\n", + " \"path\": row[\"source_url\"],\n", + " \"tile\": row[\"tile\"],\n", + " \"date\": row[\"date\"],\n", + " \"year\": int(row[\"year\"]),\n", + " \"idx\": row[\"idx\"],\n", + " \"box\": row[\"box\"].bounds,\n", + " \"image\": row[\"image\"],\n", + " }\n", + " )\n", + " idx += 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfab482d-56e7-4d2b-af72-e4ddce555667", + "metadata": {}, + "outputs": [], + "source": [ + "# Combine the geodataframes into one\n", + "embeddings_gdf = pd.concat(gdfs, ignore_index=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41e8633a-4db6-4487-958b-48ed955426cc", + "metadata": {}, + "outputs": [], + "source": [ + "# Drop existing table if any\n", + "try:\n", + " db.drop_table(\"clay-v001\")\n", + "except FileNotFoundError:\n", + " pass\n", + "db.table_names()" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "19fde8ff-9970-44f9-9798-3d701961621f", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a new table with the embeddings data\n", + "tbl = db.create_table(\"clay-v001\", data=data, mode=\"overwrite\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5e4fe11-87c8-4830-bb44-4c63ac981329", + "metadata": {}, + "outputs": [], + "source": [ + "# Select a random embedding for the search query\n", + "idx = random.randint(0, len(embeddings_gdf))\n", + "v = tbl.to_pandas().iloc[idx][\"vector\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3428bb77-9c85-4d28-9675-ad1718dbff53", + "metadata": {}, + "outputs": [], + "source": [ + "# Perform the search\n", + "search_x_images = 6\n", + "result = tbl.search(query=v).limit(search_x_images).to_pandas()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5529e840-052b-4a43-b4b7-41ed3062db62", + "metadata": {}, + "outputs": [], + "source": [ + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f347667-ed4d-48e6-9f79-855b866073dc", + "metadata": {}, + "outputs": [], + "source": [ + "def plot(df, cols=4, save=False):\n", + " \"\"\"\n", + " Plot the top similar images.\n", + "\n", + " Parameters:\n", + " df (pandas.DataFrame): DataFrame containing the search results.\n", + " cols (int): Number of columns to display in the plot.\n", + " \"\"\"\n", + " fig, axs = plt.subplots(1, cols, figsize=(20, 10))\n", + " i = 0\n", + " for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n", + " # row = df.iloc[i]\n", + " chip = np.array(row[\"image\"]).reshape(256, 256, 4)\n", + " chip = chip[:, :, :3]\n", + " ax.imshow(chip)\n", + " ax.set_title(f\"{row['idx']}\")\n", + " i += 1\n", + " plt.tight_layout()\n", + " if save:\n", + " fig.savefig(\"similar.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc343525-ebbd-447d-a00b-0ae7582e88c6", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the top similar images\n", + "plot(result, search_x_images)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf91604b-0bfd-4a71-ad12-dc13b6eba660", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/v1-inference-simsearch-naip.ipynb b/nbs/v1-inference-simsearch-naip.ipynb new file mode 100644 index 00000000..d3e1dc8d --- /dev/null +++ b/nbs/v1-inference-simsearch-naip.ipynb @@ -0,0 +1,664 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "db0c561a", + "metadata": {}, + "source": [ + "# NAIP inference and similarity search with Clay v1\n", + "This notebook walks through Clay model v1 inference on NAIP (National Agriculture Imagery Program) data and similarity search. The workflow includes loading and preprocessing data from STAC, tiling the images and encoding metadata, generating embeddings and querying across them for similar representations. The NAIP data comes in annual composites. 
We are using data from one year within a sampled region in San Francisco, California.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0c141b9-4038-4542-832c-f71e04bd93c1", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48002199-2fab-4d85-aeba-25f651865f98", + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "import math\n", + "import os\n", + "import random\n", + "\n", + "import geopandas as gpd\n", + "import lancedb\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import rioxarray # noqa: F401\n", + "import shapely\n", + "import torch\n", + "import xarray as xr\n", + "import yaml\n", + "from box import Box\n", + "from pystac_client import Client\n", + "from stacchip.processors.prechip import normalize_timestamp\n", + "from torchvision.transforms import v2\n", + "\n", + "from src.model_clay_v1 import ClayMAEModule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f950c34e-2c8a-4d02-9d72-c161321c826e", + "metadata": {}, + "outputs": [], + "source": [ + "# STAC API endpoint and platform details\n", + "STAC_API = \"https://planetarycomputer.microsoft.com/api/stac/v1\"\n", + "PLATFORM_NAME = \"naip\"\n", + "\n", + "# Directory to save the downloaded data\n", + "save_dir = \"./data_naip_test/ca/2020/sf/\"\n", + "os.makedirs(save_dir, exist_ok=True)\n", + "\n", + "YEAR = 2020\n", + "\n", + "# STAC API search query\n", + "search_query = {\n", + " \"collections\": [PLATFORM_NAME],\n", + " \"bbox\": [-122.6, 37.6, -122.35, 37.85], # Part of San Francisco, CA\n", + " \"datetime\": f\"{YEAR}-01-01T00:00:00Z/{YEAR+1}-01-01T00:00:00Z\",\n", + "}\n", + "\n", + "client = Client.open(STAC_API)\n", + "items = client.search(**search_query)\n", + "\n", + "stackstac_datasets = []\n", + "granule_names = []\n", + "\n", + "# Iterate over the granule names and fetch the corresponding StackSTAC data arrays\n", + "for item in items.get_all_items():\n", + " assets = item.assets\n", + " dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n", + " granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n", + " stackstac_datasets.append(dataset)\n", + " granule_names.append(granule_name)" + ] + }, + { + "cell_type": "markdown", + "id": "589611d4", + "metadata": {}, + "source": [ + "Below are some functions we will rely on to prepare the data cubes, generate embeddings and plot subsets of the tiled images for visualization purposes." 
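Before defining them, it may help to see what the location encoding produces. Below is a minimal, self-contained check of `normalize_latlon` (the same formula as in the function defined further down) using an approximate San Francisco coordinate; the printed values are rounded and shown only for illustration.

```python
import math

import numpy as np


def normalize_latlon(lat, lon):
    """Encode latitude and longitude as sine/cosine pairs on the unit circle."""
    lat = lat * np.pi / 180
    lon = lon * np.pi / 180
    return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))


lat_norm, lon_norm = normalize_latlon(37.77, -122.42)
print(lat_norm)  # approximately (0.612, 0.790)
print(lon_norm)  # approximately (-0.844, -0.536)
```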
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be577195-74e1-4b5e-b871-ef2b4e14a662", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_rgb(stack):\n", + " \"\"\"\n", + " Plot the RGB bands of the given stack.\n", + "\n", + " Parameters:\n", + " stack (xarray.DataArray): The input data array containing band information.\n", + " \"\"\"\n", + " stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n", + " plt.show()\n", + "\n", + "\n", + "def normalize_latlon(lat, lon):\n", + " \"\"\"\n", + " Normalize latitude and longitude to a range between -1 and 1.\n", + "\n", + " Parameters:\n", + " lat (float): Latitude value.\n", + " lon (float): Longitude value.\n", + "\n", + " Returns:\n", + " tuple: Normalized latitude and longitude values.\n", + " \"\"\"\n", + " lat = lat * np.pi / 180\n", + " lon = lon * np.pi / 180\n", + "\n", + " return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n", + "\n", + "\n", + "def load_model(ckpt, device=\"cuda\"):\n", + " \"\"\"\n", + " Load a pretrained Clay model from a checkpoint.\n", + "\n", + " Parameters:\n", + " ckpt (str): Path to the model checkpoint.\n", + " device (str): Device to load the model onto (default is 'cuda').\n", + "\n", + " Returns:\n", + " model: Loaded model.\n", + " \"\"\"\n", + " torch.set_default_device(device)\n", + " model = ClayMAEModule.load_from_checkpoint(\n", + " ckpt, metadata_path=\"../configs/metadata.yaml\", shuffle=False, mask_ratio=0\n", + " )\n", + " model.eval()\n", + " return model.to(device)\n", + "\n", + "\n", + "def prep_datacube(stack, lat, lon, device):\n", + " \"\"\"\n", + " Prepare a data cube for model input.\n", + "\n", + " Parameters:\n", + " stack (xarray.DataArray): The input data stack.\n", + " lat (float): Latitude value for the location.\n", + " lon (float): Longitude value for the location.\n", + " device (str): Device to load the data onto.\n", + "\n", + " Returns:\n", + " dict: Prepared data cube with normalized values and embeddings.\n", + " \"\"\"\n", + " platform = \"naip\"\n", + "\n", + " # Extract mean, std, and wavelengths from metadata\n", + " metadata = Box(yaml.safe_load(open(\"../configs/metadata.yaml\")))\n", + " mean = []\n", + " std = []\n", + " waves = []\n", + " for band in stack.band:\n", + " mean.append(metadata[platform].bands.mean[str(band.values)])\n", + " std.append(metadata[platform].bands.std[str(band.values)])\n", + " waves.append(metadata[platform].bands.wavelength[str(band.values)])\n", + "\n", + " transform = v2.Compose(\n", + " [\n", + " v2.Normalize(mean=mean, std=std),\n", + " ]\n", + " )\n", + "\n", + " # Prep datetimes embedding\n", + " datetimes = stack.time.values.astype(\"datetime64[s]\").tolist()\n", + " times = [normalize_timestamp(dat) for dat in datetimes]\n", + " week_norm = [dat[0] for dat in times]\n", + " hour_norm = [dat[1] for dat in times]\n", + "\n", + " # Prep lat/lon embedding\n", + " latlons = [normalize_latlon(lat, lon)] * len(times)\n", + " lat_norm = [dat[0] for dat in latlons]\n", + " lon_norm = [dat[1] for dat in latlons]\n", + "\n", + " # Prep pixels\n", + " pixels = torch.from_numpy(stack.data.astype(np.float32))\n", + " pixels = transform(pixels)\n", + "\n", + " # Prepare additional information\n", + " return {\n", + " \"pixels\": pixels.to(device),\n", + " \"time\": torch.tensor(\n", + " np.hstack((week_norm, hour_norm)),\n", + " dtype=torch.float32,\n", + " device=device,\n", + " ),\n", + " \"latlon\": torch.tensor(\n", + " np.hstack((lat_norm, lon_norm)), 
dtype=torch.float32, device=device\n", + " ),\n", + " \"gsd\": torch.tensor(stack.gsd.values, device=device),\n", + " \"waves\": torch.tensor(waves, device=device),\n", + " }\n", + "\n", + "\n", + "def generate_embeddings(model, datacube):\n", + " \"\"\"\n", + " Generate embeddings from the model using the data cube.\n", + "\n", + " Parameters:\n", + " model (ClayMAEModule): The pretrained model.\n", + " datacube (dict): Prepared data cube.\n", + "\n", + " Returns:\n", + " numpy.ndarray: Generated embeddings.\n", + " \"\"\"\n", + " with torch.no_grad():\n", + " unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n", + "\n", + " # The first embedding is the class token, which is the\n", + " # overall single embedding.\n", + " return unmsk_patch[:, 0, :].cpu().numpy()\n", + "\n", + "\n", + "def tile_dataset(dataset, granule_name):\n", + " \"\"\"\n", + " Tile dataset into 256x256 image chips and drop any excess border regions.\n", + "\n", + " Parameters:\n", + " dataset (xarray.DataArray): Input dataset to be tiled.\n", + " granule_name (str): Name of the granule.\n", + "\n", + " Returns:\n", + " tuple: List of tiles and their corresponding names.\n", + " \"\"\"\n", + " dataset = dataset.transpose(\"band\", \"y\", \"x\")\n", + "\n", + " # Crop the dataset to remove excess border regions\n", + " cropped_dataset = dataset.isel(x=slice(1, -1), y=slice(1, -1))\n", + "\n", + " # Determine the number of tiles in x and y dimensions\n", + " num_x_tiles = cropped_dataset.x.size // 256\n", + " num_y_tiles = cropped_dataset.y.size // 256\n", + "\n", + " # Iterate over each tile\n", + " tiles = []\n", + " tile_names = []\n", + " for x_idx in range(num_x_tiles):\n", + " for y_idx in range(num_y_tiles):\n", + " # Calculate the coordinates for this tile\n", + " x_start = x_idx * 256\n", + " y_start = y_idx * 256\n", + " x_end = x_start + 256\n", + " y_end = y_start + 256\n", + "\n", + " # Extract the tile from the cropped dataset\n", + " tile = cropped_dataset.isel(\n", + " x=slice(x_start, x_end), y=slice(y_start, y_end)\n", + " )\n", + "\n", + " # Calculate the centroid\n", + " # centroid_x = (tile.x * tile).sum() / tile.sum()\n", + " # centroid_y = (tile.y * tile).sum() / tile.sum()\n", + "\n", + " # lon = centroid_x.item()\n", + " # lat = centroid_y.item()\n", + "\n", + " tile = tile.assign_coords(band=[\"red\", \"green\", \"blue\", \"nir\"])\n", + " tile_save = tile\n", + "\n", + " time_coord = xr.DataArray([\"2020-01-01\"], dims=\"time\", name=\"time\")\n", + " tile = tile.expand_dims(time=[0])\n", + " tile = tile.assign_coords(time=time_coord)\n", + "\n", + " gsd_coord = xr.DataArray([0.6], dims=\"gsd\", name=\"gsd\")\n", + " tile = tile.expand_dims(gsd=[0])\n", + " tile = tile.assign_coords(gsd=gsd_coord)\n", + "\n", + " tile_name = f\"{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n", + "\n", + " tile_path = f\"{save_dir}/{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n", + " tile_save.rio.to_raster(tile_path)\n", + " tiles.append(tile)\n", + " tile_names.append(tile_name)\n", + "\n", + " return tiles, tile_names" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d59b1297", + "metadata": {}, + "outputs": [], + "source": [ + "# Flag to control whether to make tiles or load existing ones\n", + "make_tiles = False\n", + "\n", + "if make_tiles:\n", + " tiles_ = []\n", + " tile_names_ = []\n", + "\n", + " # Tile each dataset\n", + " for dataset, granule_name in zip(stackstac_datasets, granule_names):\n", + " tiles, tile_names = tile_dataset(dataset, granule_name)\n", 
+ " tiles_.append(tiles)\n", + " tile_names_.append(tile_names)\n", + " # Flatten sublists\n", + " tiles__ = [tile for tile in tiles for tile_ in tiles_]\n", + " tile_names__ = [tile for tile in tile_names for tile_ in tile_names_]\n", + "else:\n", + " tiles__ = []\n", + " tile_names__ = []\n", + " for filename in os.listdir(save_dir):\n", + " if filename.endswith(\".tif\"):\n", + " tile_names__.append(filename)\n", + " file_path = os.path.join(save_dir, filename)\n", + " data_array = rioxarray.open_rasterio(file_path)\n", + " tiles__.append(data_array)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db409223-881a-4e2f-8710-55851ee1f974", + "metadata": {}, + "outputs": [], + "source": [ + "len(tiles__)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "419e0d80-7250-4397-b383-db027813536a", + "metadata": {}, + "outputs": [], + "source": [ + "outdir_embeddings = \"./data_naip_test/ca/2020/sf_embeddings\"\n", + "os.makedirs(outdir_embeddings, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "d8f01975", + "metadata": {}, + "source": [ + "### Load the trained Clay v1 model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59d0bdbc", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the pretrained model\n", + "model = load_model(\n", + " # ckpt=\"s3://clay-model-ckpt/v0.5.3/mae_v0.5.3_epoch-29_val-loss-0.3073.ckpt\",\n", + " # ckpt=\"../checkpoints/v0.5.3/mae_v0.5.3_epoch-08_val-loss-0.3150.ckpt\",\n", + " ckpt=\"s3://clay-model-ckpt/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\",\n", + " device=\"cuda\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e5ae07f7", + "metadata": {}, + "source": [ + "### Generate embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05409114-2bba-4497-8cbb-a7303a8d5be6", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "embeddings = []\n", + "i = 0\n", + "for tile, fname in zip(tiles__, tile_names__):\n", + " # Calculate the centroid\n", + " centroid_x = (tile.x * tile).sum() / tile.sum()\n", + " centroid_y = (tile.y * tile).sum() / tile.sum()\n", + "\n", + " lon = centroid_x.item()\n", + " lat = centroid_y.item()\n", + "\n", + " datacube = prep_datacube(tile, lat, lon, model.device)\n", + " embeddings_ = generate_embeddings(model, datacube)\n", + " embeddings.append(embeddings_)\n", + "\n", + " date = tile.time\n", + " data = {\n", + " \"source_url\": str(fname[:-4]),\n", + " \"date\": pd.to_datetime(arg=date, format=\"%Y-%m-%d\").astype(\n", + " dtype=\"date32[day][pyarrow]\"\n", + " ),\n", + " \"embeddings\": [np.ascontiguousarray(embeddings_.squeeze())],\n", + " }\n", + "\n", + " # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)\n", + " box_ = tile.rio.bounds()\n", + " box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n", + "\n", + " # Create the GeoDataFrame\n", + " gdf = gpd.GeoDataFrame(\n", + " data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n", + " )\n", + "\n", + " # Reproject to WGS84 (lon/lat coordinates)\n", + " gdf = gdf.to_crs(epsg=4326)\n", + "\n", + " outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n", + " gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n", + " print(\n", + " f\"Saved {len(gdf)} rows of embeddings of \"\n", + " f\"shape {gdf.embeddings.iloc[0].shape} to {outpath}\"\n", + " )\n", + " i += 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"8ec8c89b-38fe-44e3-9780-e1df8a3f41c2", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bf9e1fc-b432-4ec3-a028-f85d2ff57469", + "metadata": {}, + "outputs": [], + "source": [ + "len(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "a9cf665b", + "metadata": {}, + "source": [ + "### Run a similarity search to identify similar embeddings\n", + "We will select a random index to search with and plot the corresponding RGB images from the search results. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99132892-c6bd-4c01-a92c-f54e1cba1921", + "metadata": {}, + "outputs": [], + "source": [ + "# Connect to the embeddings database\n", + "db = lancedb.connect(\"embeddings\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edf872d0-6b90-4102-9dc3-c726527c19dc", + "metadata": {}, + "outputs": [], + "source": [ + "# Data for DB table\n", + "data = []\n", + "# Dataframe to find overlaps within\n", + "gdfs = []\n", + "idx = 0\n", + "for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n", + " gdf = gpd.read_parquet(emb)\n", + " gdf[\"year\"] = gdf.date.dt.year\n", + " gdf[\"tile\"] = gdf[\"source_url\"]\n", + " gdf[\"idx\"] = idx\n", + " gdf[\"box\"] = [shapely.geometry.box(*geom.bounds) for geom in gdf.geometry]\n", + " gdfs.append(gdf)\n", + "\n", + " for _, row in gdf.iterrows():\n", + " data.append(\n", + " {\n", + " \"vector\": row[\"embeddings\"],\n", + " \"path\": row[\"source_url\"],\n", + " \"tile\": row[\"tile\"],\n", + " \"date\": row[\"date\"],\n", + " \"year\": int(row[\"year\"]),\n", + " \"idx\": row[\"idx\"],\n", + " \"box\": row[\"box\"].bounds,\n", + " }\n", + " )\n", + " idx += 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfab482d-56e7-4d2b-af72-e4ddce555667", + "metadata": {}, + "outputs": [], + "source": [ + "# Combine patch level geodataframes into one\n", + "embeddings_gdf = pd.concat(gdfs, ignore_index=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41e8633a-4db6-4487-958b-48ed955426cc", + "metadata": {}, + "outputs": [], + "source": [ + "# Drop existing table if any\n", + "db.drop_table(\"clay-v001\")\n", + "db.table_names()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19fde8ff-9970-44f9-9798-3d701961621f", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a new table with the embeddings data\n", + "tbl = db.create_table(\"clay-v001\", data=data, mode=\"overwrite\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5e4fe11-87c8-4830-bb44-4c63ac981329", + "metadata": {}, + "outputs": [], + "source": [ + "# Select a random embedding for the search query\n", + "idx = random.randint(0, len(embeddings_gdf))\n", + "v = tbl.to_pandas().iloc[idx][\"vector\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3428bb77-9c85-4d28-9675-ad1718dbff53", + "metadata": {}, + "outputs": [], + "source": [ + "# Perform the search\n", + "result = tbl.search(query=v).limit(10).to_pandas()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5529e840-052b-4a43-b4b7-41ed3062db62", + "metadata": {}, + "outputs": [], + "source": [ + "result.path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f347667-ed4d-48e6-9f79-855b866073dc", + "metadata": {}, + "outputs": [], + "source": [ + "def plot(df, cols=10):\n", + " \"\"\"\n", + " Plot the top similar images.\n", + 
"\n", + " Parameters:\n", + " df (pandas.DataFrame): DataFrame containing the search results.\n", + " cols (int): Number of columns to display in the plot.\n", + " \"\"\"\n", + " fig, axs = plt.subplots(1, cols, figsize=(20, 10))\n", + " i = 0\n", + " for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n", + " # row = df.iloc[i]\n", + " path = row[\"path\"]\n", + " chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n", + " band=[\"red\", \"green\", \"blue\"]\n", + " )\n", + " chip = chip.squeeze().transpose(\"x\", \"y\", \"band\")\n", + " ax.imshow(chip)\n", + " ax.set_title(f\"{row['idx']}\")\n", + " ax.set_axis_off()\n", + " i += 1\n", + " plt.tight_layout()\n", + " fig.savefig(\"similar.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc343525-ebbd-447d-a00b-0ae7582e88c6", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the top similar images\n", + "plot(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b4159d5-6a0a-41c6-8d3f-6078d419a420", + "metadata": {}, + "outputs": [], + "source": [ + "len(embeddings_gdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf91604b-0bfd-4a71-ad12-dc13b6eba660", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}