diff --git a/environment.yml b/environment.yml index 699df9df..ac0ddeaf 100644 --- a/environment.yml +++ b/environment.yml @@ -7,33 +7,36 @@ dependencies: - einops~=0.7.0 - fiona~=1.9.5 - geopandas-base~=0.14.1 - - h5netcdf~=1.3.0 - - jupyter-book~=1.0.0 - - jupyterlab~=4.0.7 - jsonargparse~=4.27.0 - - lancedb~=0.10.2 - lightning~=2.1.0 - matplotlib-base~=3.8.2 - planetary-computer~=1.0.0 - python-box~=7.1.0 - - pytorch~=2.1.0 # [osx] - - pytorch~=2.1.0 *cuda12* # [linux] + - pytorch~=2.3.1 # [osx] + - pytorch~=2.3.1 *cuda12* # [linux] - python~=3.11.0 - pyarrow~=16.1.0 - - rioxarray~=0.15.0 - rasterio~=1.3.10 - s3fs~=2024.3.1 - scikit-image~=0.22.0 - scikit-learn~=1.4.0 - stackstac~=0.5.0 - timm~=0.9.16 - - torchdata~=0.7.1 - - torchgeo~=0.5.2 - - torchvision~=0.16.1 + - torchvision~=0.18.1 - transformers~=4.35.2 - typeshed-client~=2.4.0 - vit-pytorch~=1.6.4 - - wandb~=0.15.12 - zarr~=2.16.1 + - pip: + - geoarrow-pyarrow==0.1.2 + - jupyter-book==1.0.2 + - jupyterlab==4.2.4 + - onnx==1.16.1 + - onnxscript + - onnxruntime + - torchdata==0.7.1 + - torchgeo==0.5.2 + - stacchip==0.1.35 + - wandb==0.17.5 platforms: - linux-64 diff --git a/finetune/embedder/factory.py b/finetune/embedder/factory.py new file mode 100644 index 00000000..bf3ee6e4 --- /dev/null +++ b/finetune/embedder/factory.py @@ -0,0 +1,303 @@ +"""Export the Clay model to ONNX and pytorch ExportedProgram format. + +This script exports the Clay model to ONNX and pytorch ExportedProgram format +for deployment. The model is exported with dynamic shapes for inference. + +How to use: + +```bash +python -m finetune.embedder.factory \ + --img_size 256 \ + --ckpt_path checkpoints/clay-v1-base.ckpt \ + --device cuda \ + --name clay-v1-encoder.onnx \ + --onnx +# exports Clay encoder to ONNX format that can handle chips of size 256x256 +# for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ & Sentinel 1. +``` + +```bash +python -m finetune.embedder.factory \ + --img_size 224 \ + --ckpt_path checkpoints/clay-v1-base.ckpt \ + --device cuda \ + --name clay-v1-encoder.pt2 \ + --ep +# exports Clay encoder to pytorch ExportedProgram format that can handle chips +# of size 224x224 for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ +# & Sentinel 1. 
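+# Note: both commands write their output to checkpoints/compiled/. The .pt2 artifact
+# can be loaded back with torch.export.load(...).module() and the .onnx artifact with
+# onnxruntime.InferenceSession(...); see finetune/embedder/how-to-embed.ipynb for a
+# worked example.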
+```
+
+"""
+
+import argparse
+import re
+import warnings
+from pathlib import Path
+
+import torch
+from einops import repeat
+from torch import nn
+from torch.export import Dim
+
+from src.model import Encoder
+from src.utils import posemb_sincos_2d_with_gsd
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+class EmbeddingEncoder(Encoder):
+    """Clay Encoder without mask and shuffle."""
+
+    def __init__(  # noqa: PLR0913
+        self,
+        img_size,
+        patch_size,
+        dim,
+        depth,
+        heads,
+        dim_head,
+        mlp_ratio,
+    ):
+        super().__init__(
+            mask_ratio=0.0,
+            shuffle=False,
+            patch_size=patch_size,
+            dim=dim,
+            depth=depth,
+            heads=heads,
+            dim_head=dim_head,
+            mlp_ratio=mlp_ratio,
+        )
+        self.img_size = img_size
+
+        # Using fixed grid size for inference
+        self.grid_size = img_size // patch_size
+        self.num_patches = self.grid_size**2
+
+    def add_encodings(self, patches, time, latlon, gsd):
+        """Add position encoding to the patches"""
+        B, L, D = patches.shape
+
+        grid_size = self.grid_size
+
+        pos_encoding = (
+            posemb_sincos_2d_with_gsd(
+                h=grid_size,
+                w=grid_size,
+                dim=(self.dim - 8),
+                gsd=gsd,
+            )
+            .to(patches.device)
+            .detach()
+        )  # [L (D - 8)]
+
+        time_latlon = torch.hstack((time, latlon)).to(patches.device).detach()  # [B 8]
+
+        pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B)  # [B L (D - 8)]
+        time_latlon = repeat(time_latlon, "B D -> B L D", L=L)  # [B L 8]
+        pos_metadata_encoding = torch.cat(
+            (pos_encoding, time_latlon), dim=-1
+        )  # [B L D]
+
+        patches = patches + pos_metadata_encoding  # [B L D] + [B L D] -> [B L D]
+        return patches  # [B L D]
+
+    def forward(self, datacube):
+        cube, time, latlon, gsd, waves = (
+            datacube["pixels"],  # [B C H W]
+            datacube["time"],  # [B 4]
+            datacube["latlon"],  # [B 4]
+            datacube["gsd"],  # [1]
+            datacube["waves"],  # [N]
+        )
+        B, C, H, W = cube.shape
+
+        patches, _ = self.to_patch_embed(
+            cube, waves
+        )  # [B L D] - patchify & create embeddings per patch
+
+        # Add time & latlon as encoding to patches
+        patches = self.add_encodings(
+            patches,
+            time,
+            latlon,
+            gsd,
+        )  # [B L D] - add position encoding to the embeddings
+
+        # Add class tokens
+        cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B)  # [B 1 D]
+        patches = torch.cat((cls_tokens, patches), dim=1)  # [B (1 + L) D]
+
+        # Pass the patches through the transformer
+        patches = self.transformer(patches)  # [B (1 + L) D]
+
+        # Get the cls token
+        embeddings = patches[:, 0, :]  # [B D]
+
+        return embeddings
+
+
+class Embedder(nn.Module):
+    def __init__(self, img_size=256, ckpt_path=None, device="cpu"):
+        super().__init__()
+        self.clay_encoder = (
+            EmbeddingEncoder(  # Default parameters for the Clay base model
+                img_size=img_size,
+                patch_size=8,
+                dim=768,
+                depth=12,
+                heads=12,
+                dim_head=64,
+                mlp_ratio=4.0,
+            ).to(device)
+        )
+        self.img_size = img_size
+        self.device = torch.device(device)
+        self.load_clay_weights(ckpt_path)
+
+    def load_clay_weights(self, ckpt_path):
+        "Load the encoder weights from a Clay model checkpoint."
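+        # The checkpoint is assumed to be a Lightning-style Clay checkpoint whose
+        # "state_dict" keys are prefixed with "model.encoder."; the prefix is
+        # stripped so the names match the parameters of this standalone encoder.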
+ ckpt = torch.load(ckpt_path, map_location=self.device) + state_dict = ckpt.get("state_dict") + state_dict = { + re.sub(r"^model\.encoder\.", "", name): param + for name, param in state_dict.items() + if name.startswith("model.encoder") + } + + with torch.no_grad(): + for name, param in self.clay_encoder.named_parameters(): + if name in state_dict and param.size() == state_dict[name].size(): + param.data.copy_(state_dict[name]) # Copy the weights + else: + print(f"No matching parameter for {name} with size {param.size()}") + + for param in self.clay_encoder.parameters(): + param.requires_grad = False + + self.clay_encoder.eval() + + def forward(self, datacube): + embeddings = self.clay_encoder(datacube) + + return embeddings + + def fake_datacube(self): + "Generate a fake datacube for model export." + dummy_datacube = { + "pixels": torch.randn(2, 3, self.img_size, self.img_size), + "time": torch.randn(2, 4), + "latlon": torch.randn(2, 4), + "waves": torch.randn(3), + "gsd": torch.randn(1), + } + dummy_datacube = {k: v.to(self.device) for k, v in dummy_datacube.items()} + return dummy_datacube + + def export_to_onnx(self, name): + "Save the model to ONNX format." + + datacube = self.fake_datacube() + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + + # Export the model to ONNX format + onnx_program = torch.onnx.dynamo_export( + self.eval(), datacube, export_options=export_options + ) + + # Save the exported model + onnx_program.save(f"checkpoints/compiled/{name}") + print(f"Model exported to ONNX format: checkpoints/compiled/{name}") + + return onnx_program + + def export_to_torchep(self, name): + "Save the model to pytorch ExportedProgram format." + + datacube = self.fake_datacube() + + # dynamic shapes for model export + batch_size = Dim("batch_size", min=2, max=1000) + channel_bands = Dim("channel_bands", min=1, max=10) + dynamic_shapes = { + "datacube": { + "pixels": {0: batch_size, 1: channel_bands}, + "time": {0: batch_size}, + "latlon": {0: batch_size}, + "waves": {0: channel_bands}, + "gsd": {0: None}, + } + } + + # Export the model to pytorch ExportedProgram format + ep = torch.export.export( + self.eval(), + (datacube,), + dynamic_shapes=dynamic_shapes, + strict=True, + ) + + # Save the exported model + torch.export.save(ep, f"checkpoints/compiled/{name}") + print( + f"Model exported to pytorch ExportedProgram format: checkpoints/compiled/{name}" # noqa: E501 + ) + + return ep + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Export the Clay model.") + parser.add_argument( + "--img_size", + type=int, + default=256, + help="Image size for the model", + ) + parser.add_argument( + "--ckpt_path", + type=str, + default="checkpoints/clay-v1-base.ckpt", + help="Path to the Clay model checkpoint", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for the model", + ) + parser.add_argument( + "--name", + type=str, + default="clay-base.pt", + help="Name of the exported model", + ) + parser.add_argument( + "--onnx", + action="store_true", + help="Export the model to ONNX format", + ) + parser.add_argument( + "--ep", + action="store_true", + help="Export the model to pytorch ExportedProgram format", + ) + + args = parser.parse_args() + + Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True) + embedder = Embedder( + img_size=args.img_size, + ckpt_path=args.ckpt_path, + device=args.device, + ) + + if args.onnx: + embedder.export_to_onnx(args.name) + elif args.ep: + 
        embedder.export_to_torchep(args.name)
+    else:
+        print("Please specify the format to export the model.")
+        parser.print_help()
diff --git a/finetune/embedder/how-to-embed.ipynb b/finetune/embedder/how-to-embed.ipynb
new file mode 100644
index 00000000..5f482846
--- /dev/null
+++ b/finetune/embedder/how-to-embed.ipynb
@@ -0,0 +1,494 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9960547-640d-425c-8180-fc5523a80e42",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "import os\n",
+    "import requests\n",
+    "import warnings\n",
+    "\n",
+    "import geoarrow.pyarrow as ga\n",
+    "import numpy as np\n",
+    "import pystac_client\n",
+    "import pyarrow as pa\n",
+    "import pyarrow.parquet as pq\n",
+    "import torch\n",
+    "import yaml\n",
+    "from box import Box\n",
+    "from torchvision.transforms import v2\n",
+    "\n",
+    "from stacchip.indexer import Sentinel2Indexer\n",
+    "from stacchip.chipper import Chipper\n",
+    "\n",
+    "warnings.filterwarnings(\"ignore\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "598fec81-2cc1-4c5a-9e46-7c46a5591484",
+   "metadata": {},
+   "source": [
+    "### Find data for AOI\n",
+    "The first step is to find STAC items of imagery that we want to use to create embeddings. In this example we use Sentinel-2 L2A imagery from the Earth Search STAC catalog hosted by Element 84.\n",
+    "\n",
+    "We also search across a time range so that we obtain multiple embeddings for the same location at different moments in time."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3e1d46ee-40f6-49f5-99ad-83819339561e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Point over Monchique, Portugal\n",
+    "lat, lon = 37.30939, -8.57207\n",
+    "\n",
+    "# Dates of a large forest fire\n",
+    "start = \"2018-07-01\"\n",
+    "end = \"2018-09-01\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b7825318-23f3-449f-9104-eae6562a55ab",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optimize GDAL settings for cloud optimized reading\n",
+    "os.environ[\"GDAL_DISABLE_READDIR_ON_OPEN\"] = \"EMPTY_DIR\"\n",
+    "os.environ[\"AWS_REQUEST_PAYER\"] = \"requester\"\n",
+    "\n",
+    "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n",
+    "COLLECTION = \"sentinel-2-l2a\"\n",
+    "\n",
+    "# Search the catalogue\n",
+    "catalog = pystac_client.Client.open(STAC_API)\n",
+    "search = catalog.search(\n",
+    "    collections=[COLLECTION],\n",
+    "    datetime=f\"{start}/{end}\",\n",
+    "    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),\n",
+    "    max_items=100,\n",
+    "    query={\"eo:cloud_cover\": {\"lt\": 80}},\n",
+    ")\n",
+    "\n",
+    "all_items = search.get_all_items()\n",
+    "\n",
+    "# Reduce to one item per date (there might be some duplicates\n",
+    "# based on the location)\n",
+    "items = []\n",
+    "dates = []\n",
+    "for item in all_items:\n",
+    "    if item.datetime.date() not in dates:\n",
+    "        items.append(item)\n",
+    "        dates.append(item.datetime.date())\n",
+    "\n",
+    "print(f\"Found {len(items)} items\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "600f3cfb-ce4e-4409-ae15-20f3a7107a62",
+   "metadata": {},
+   "source": [
+    "To speed up processing in this example, we limit the number of chips to 3 per Sentinel-2 scene. Remove this limit in a real use case."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "183975c7-8afb-49ef-8e70-790265719aea",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chips = []\n",
+    "datetimes = []\n",
+    "bboxs = []\n",
+    "chip_ids = []\n",
+    "item_ids = []\n",
+    "\n",
+    "for item in items:\n",
+    "    print(f\"Working on {item}\")\n",
+    "\n",
+    "    # Index the chips in the item\n",
+    "    indexer = Sentinel2Indexer(item)\n",
+    "\n",
+    "    # Instantiate the chipper\n",
+    "    chipper = Chipper(indexer, assets=[\"red\", \"green\", \"blue\", \"nir\", \"scl\"])\n",
+    "\n",
+    "    # Keep only the first three chips per item to speed up the example\n",
+    "    for idx, (x, y, chip) in enumerate(chipper):\n",
+    "        if idx > 2:\n",
+    "            break\n",
+    "        del chip[\"scl\"]\n",
+    "        chips.append(chip)\n",
+    "        datetimes.append(item.datetime)\n",
+    "        bboxs.append(indexer.get_chip_bbox(x, y))\n",
+    "        chip_ids.append((x, y))\n",
+    "        item_ids.append(item.id)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "71902ab7-3320-43cd-85c3-362c2500f241",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pixels = np.array([np.array(list(chip.values())).squeeze() for chip in chips])\n",
+    "pixels.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6f7ce367-4e12-4648-bb79-119b4f50ead8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Extract mean, std, and wavelengths from metadata\n",
+    "platform = \"sentinel-2-l2a\"\n",
+    "# Retrieve the file content from the URL\n",
+    "\n",
+    "url = (\n",
+    "    \"https://raw.githubusercontent.com/Clay-foundation/model/main/configs/metadata.yaml\"\n",
+    ")\n",
+    "response = requests.get(url, allow_redirects=True)\n",
+    "\n",
+    "# Convert bytes to string\n",
+    "content = response.content.decode(\"utf-8\")\n",
+    "\n",
+    "# Load the yaml\n",
+    "content = yaml.safe_load(content)\n",
+    "\n",
+    "metadata = Box(content)\n",
+    "mean = []\n",
+    "std = []\n",
+    "waves = []\n",
+    "# Use the band names to get the correct values in the correct order.\n",
+    "for band in chips[0].keys():\n",
+    "    mean.append(metadata[platform].bands.mean[band])\n",
+    "    std.append(metadata[platform].bands.std[band])\n",
+    "    waves.append(metadata[platform].bands.wavelength[band])\n",
+    "\n",
+    "# Prepare the normalization transform function using the mean and std values.\n",
+    "transform = v2.Compose(\n",
+    "    [\n",
+    "        v2.Normalize(mean=mean, std=std),\n",
+    "    ]\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a8ec8c2d-ecb9-42a2-9e8c-3f95c67ef07b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def normalize_timestamp(date):\n",
+    "    week = date.isocalendar().week * 2 * np.pi / 52\n",
+    "    hour = date.hour * 2 * np.pi / 24\n",
+    "\n",
+    "    return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour))\n",
+    "\n",
+    "\n",
+    "times = [normalize_timestamp(dat) for dat in datetimes]\n",
+    "week_norm = [dat[0] for dat in times]\n",
+    "hour_norm = [dat[1] for dat in times]\n",
+    "\n",
+    "\n",
+    "# Prep lat/lon embedding using sine/cosine of the coordinates in radians\n",
+    "def normalize_latlon(lat, lon):\n",
+    "    lat = lat * np.pi / 180\n",
+    "    lon = lon * np.pi / 180\n",
+    "\n",
+    "    return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n",
+    "\n",
+    "\n",
+    "latlons = [normalize_latlon(lat, lon)] * len(times)\n",
+    "lat_norm = [dat[0] for dat in latlons]\n",
+    "lon_norm = [dat[1] for dat in latlons]\n",
+    "\n",
+    "# Prep gsd\n",
+    "gsd = [10]\n",
+    "\n",
+    "# Normalize pixels\n",
+    "pixels = transform(pixels)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id":
"2640eb17-a85c-4972-8d5d-e45e9ed8eba5", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {\n", + " \"pixels\": torch.tensor(pixels, dtype=torch.float32),\n", + " \"time\": torch.tensor(np.hstack((week_norm, hour_norm)), dtype=torch.float32),\n", + " \"latlon\": torch.tensor(np.hstack((lat_norm, lon_norm)), dtype=torch.float32),\n", + " \"waves\": torch.tensor(waves, dtype=torch.float32),\n", + " \"gsd\": torch.tensor(gsd, dtype=torch.float32),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f6711a9-e7ed-44d5-add7-2c3a498cd422", + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in datacube.items():\n", + " print(k, v.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "83243912-a2a8-4fa5-a39c-a9c3b07c7569", + "metadata": {}, + "source": [ + "### Clay Embedder\n", + "\n", + "#### Load the embedder that is stored in ExportedProgram format using **cpu**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4eb468af-d468-46aa-a8fb-23ff95c56288", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder-cpu.pt2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9eb797f7-5238-49e0-9950-e85f10132454", + "metadata": {}, + "outputs": [], + "source": [ + "ep_embedder_cpu = torch.export.load(\"clay-v1-encoder-cpu.pt2\").module()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eefe4811-7290-47c3-a10e-45257e6d42e0", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "with torch.no_grad():\n", + " embeddings = ep_embedder_cpu(datacube)\n", + "datacube[\"pixels\"].shape, embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "8e927b01-c855-4172-a4d9-2c10ba794ed4", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "fa0810b4-34ad-490e-bbcd-c0c3288f017c", + "metadata": {}, + "source": [ + "#### Load the embedder that is stored in ExportedProgram format using **gpu**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c1bbfd4-7dc6-4ad0-8a0b-b3745a9f35ca", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder.pt2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e285a543-20ab-44ba-b676-2303284dc477", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {k: v.to(\"cuda\") for k, v in datacube.items()}\n", + "ep_embedder = torch.export.load(\"clay-v1-encoder.pt2\").module()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edefee90-e6b8-4701-bb5d-2bf7febc806c", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "with torch.no_grad():\n", + " embeddings = ep_embedder(datacube)\n", + "datacube[\"pixels\"].shape, embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "196f2121-46b5-4b02-94d3-75e648c329c3", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "5b1cb0f9-a434-419b-a88b-4d4edd84fea6", + "metadata": {}, + "source": [ + "#### Load the embedder that is stored in ONNX format using **cpu**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa10d696-740a-458e-ae10-eec9a43fb362", + "metadata": {}, + "outputs": [], + "source": [ + "import onnx\n", + "import onnxruntime as ort" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "992524e5-2c2a-4e48-ae95-bd2aa87b72a9", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder-cpu.onnx" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc3fa967-73d5-431c-88a2-84b088aff06f", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {k: v.to(\"cpu\") for k, v in datacube.items()}\n", + "onnx_embedder = ort.InferenceSession(\n", + " \"clay-v1-encoder-cpu.onnx\", providers=[\"CPUExecutionProvider\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24591d17-d1c8-452b-9b20-676a9b6f8643", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "embeddings = onnx_embedder.run(\n", + " [],\n", + " {\n", + " \"cube\": datacube[\"pixels\"].numpy(),\n", + " \"time\": datacube[\"time\"].numpy(),\n", + " \"latlon\": datacube[\"latlon\"].numpy(),\n", + " \"waves\": datacube[\"waves\"].numpy(),\n", + " \"gsd\": datacube[\"gsd\"].numpy(),\n", + " },\n", + ")[0]\n", + "embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "9c07216e-a109-4cd8-8c74-9a3fc9a37757", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "2e8d5900-9a4b-4e2d-b992-4fb0a1e8c835", + "metadata": {}, + "source": [ + "### Store the results\n", + "\n", + "We create a table containing the embeddings, bounding box, the STAC item ID, the datetime of the image capture, and the chip x and y ids. Then we save that data to disk." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "677f04d3-db38-4d44-9b55-c103d54adcd5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Write data to pyarrow table\n",
+    "index = {\n",
+    "    \"datetimes\": datetimes,\n",
+    "    \"chip_ids\": chip_ids,\n",
+    "    \"item_ids\": item_ids,\n",
+    "    \"embeddings\": [np.ascontiguousarray(dat) for dat in embeddings],\n",
+    "    \"geometry\": ga.as_geoarrow([dat.wkt for dat in bboxs]),\n",
+    "}\n",
+    "table = pa.table(index)\n",
+    "table"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d62a9e8a-b4f9-491c-a437-6a164a9e74fe",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pq.write_table(table, \"embeddings.parquet\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d30fb8c7-d04d-453f-93f6-dc3599f1df15",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/utils.py b/src/utils.py
index 539a2acd..b0f2bcce 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -24,6 +24,7 @@ def posemb_sincos_2d_with_gsd(
     y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
     assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
 
+    gsd = gsd.to(x.device)
     omega = torch.arange(dim // 4) / (dim // 4 - 1)
     omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0)  # Adjusted for gsd
 
@@ -33,16 +34,16 @@
     return pe.type(dtype)
 
 
-def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32):
+def posemb_sincos_1d(waves, dim, temperature: int = 10000, dtype=torch.float32):
     assert (
         dim % 2 == 0
     ), "Feature dimension must be a multiple of 2 for sincos embedding"
-    pos = torch.arange(pos) if isinstance(pos, int) else pos
+    waves = torch.arange(waves) if isinstance(waves, int) else waves
 
-    omega = torch.arange(dim // 2) / (dim // 2 - 1)
+    omega = torch.arange(dim // 2, device=waves.device) / (dim // 2 - 1)
     omega = 1.0 / (temperature**omega)
 
-    scaled_pos = pos[:, None] * omega[None, :]
-    pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1)
+    scaled_waves = waves[:, None] * omega[None, :]
+    pe = torch.cat((scaled_waves.sin(), scaled_waves.cos()), dim=1)
 
     return pe.type(dtype)
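For reference, here is a minimal sketch (not part of the diff) of driving the exported encoder from a plain Python script, mirroring what the notebook does with the `.pt2` artifact. The file name, band wavelengths, and batch contents are illustrative placeholders; the tensor shapes follow `fake_datacube()` in `factory.py` and the notebook's datacube.

```python
import torch

# Load the ExportedProgram produced by factory.py (path is an assumption).
encoder = torch.export.load("checkpoints/compiled/clay-v1-encoder.pt2").module()

# Dummy datacube with the shapes the exported encoder expects
# (batch size must be >= 2 because of the dynamic-shape constraints in factory.py).
datacube = {
    "pixels": torch.randn(2, 4, 256, 256),  # [B C H W] normalized chips
    "time": torch.randn(2, 4),  # [B 4] week/hour sin-cos pairs
    "latlon": torch.randn(2, 4),  # [B 4] lat/lon sin-cos pairs
    "waves": torch.tensor([0.665, 0.56, 0.493, 0.842]),  # [C] illustrative wavelengths
    "gsd": torch.tensor([10.0]),  # [1] ground sampling distance in meters
}

with torch.no_grad():
    embeddings = encoder(datacube)

print(embeddings.shape)  # expected: torch.Size([2, 768])
```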