|
36 | 36 | "source": [ |
37 | 37 | "import sys\n", |
38 | 38 | "\n", |
39 | | - "sys.path.append(\"../model\")\n", |
40 | | - "sys.path.insert(0, \"../stacchip\")" |
| 39 | + "sys.path.append(\"..\")" |
| 40 | + ] |
| 41 | + }, |
| 42 | + { |
| 43 | + "cell_type": "code", |
| 44 | + "execution_count": null, |
| 45 | + "id": "eabd5bef", |
| 46 | + "metadata": {}, |
| 47 | + "outputs": [], |
| 48 | + "source": [ |
| 49 | + "%pip install stacchip==0.1.31" |
41 | 50 | ] |
42 | 51 | }, |
43 | 52 | { |
|
59 | 68 | "import numpy as np\n", |
60 | 69 | "import pandas as pd\n", |
61 | 70 | "import pystac_client\n", |
62 | | - "import requests\n", |
| 71 | + "import rasterio\n", |
63 | 72 | "import shapely\n", |
64 | 73 | "import torch\n", |
65 | 74 | "import yaml\n", |
66 | 75 | "from box import Box\n", |
67 | 76 | "from pyproj import Transformer\n", |
68 | | - "from rasterio.io import MemoryFile\n", |
69 | 77 | "from stacchip.chipper import Chipper\n", |
70 | 78 | "from stacchip.indexer import NoStatsChipIndexer\n", |
71 | 79 | "from stacchip.processors.prechip import normalize_timestamp\n", |
|
145 | 153 | " Returns:\n", |
146 | 154 | " tuple: Bounds coordinates and centroid coordinates.\n", |
147 | 155 | " \"\"\"\n", |
148 | | - " response = requests.get(url)\n", |
149 | | - " response.raise_for_status()\n", |
| 156 | + " with rasterio.open(url) as rst:\n", |
| 157 | + " bounds = rst.bounds\n", |
| 158 | + " transformer = Transformer.from_crs(rst.crs, 4326)\n", |
| 159 | + "\n", |
| 160 | + " centroid_x = (bounds.left + bounds.right) / 2\n", |
| 161 | + " centroid_y = (bounds.top + bounds.bottom) / 2\n", |
| 162 | + "\n", |
| 163 | + " centroid_x, centroid_y = transformer.transform(centroid_x, centroid_y)\n", |
150 | 164 | "\n", |
151 | | - " with MemoryFile(response.content) as memfile:\n", |
152 | | - " with memfile.open() as src:\n", |
153 | | - " bounds = src.bounds\n", |
154 | | - " transformer = Transformer.from_crs(src.crs, 4326)\n", |
155 | | - " # Calculate centroid\n", |
156 | | - " centroid_x = (bounds.left + bounds.right) / 2\n", |
157 | | - " centroid_y = (bounds.top + bounds.bottom) / 2\n", |
158 | | - " centroid_x, centroid_y = transformer.transform(centroid_x, centroid_y)\n", |
159 | | - " bounds_b, bounds_l = transformer.transform(bounds.left, bounds.bottom)\n", |
160 | | - " bounds_t, bounds_r = transformer.transform(bounds.right, bounds.top)\n", |
161 | | - " return [bounds_b, bounds_l, bounds_t, bounds_r], centroid_x, centroid_y" |
| 165 | + " bounds_b, bounds_l = transformer.transform(bounds.left, bounds.bottom)\n", |
| 166 | + " bounds_t, bounds_r = transformer.transform(bounds.right, bounds.top)\n", |
| 167 | + "\n", |
| 168 | + " return [bounds_b, bounds_l, bounds_t, bounds_r], [centroid_x, centroid_y]" |
162 | 169 | ] |
163 | 170 | }, |
164 | 171 | { |
|
192 | 199 | " print(f\"Bounds coordinates: {bounds}, centroid coordinates: {centroid}\")\n", |
193 | 200 | "\n", |
194 | 201 | " # Instantiate the chipper\n", |
195 | | - " chipper = Chipper(indexer, asset_blacklist=[\"metadata\"])\n", |
| 202 | + " chipper = Chipper(\n", |
| 203 | + " indexer, asset_blacklist=[\"thumbnail\", \"tilejson\", \"rendered_preview\"]\n", |
| 204 | + " )\n", |
196 | 205 | "\n", |
197 | 206 | " # Get 5 randomly sampled chips from the total\n", |
198 | 207 | " # number of chips within this item's entire image\n", |
199 | | - " for chip_id in random.sample(range(0, len(chipper)), 5):\n", |
| 208 | + " for chip_id in random.sample(range(0, len(chipper)), 25):\n", |
200 | 209 | " chips.append(chipper[chip_id])\n", |
201 | 210 | " chip_images.append(chipper[chip_id][\"image\"])\n", |
202 | 211 | " chip_bounds.append(bounds)\n", |
|
205 | 214 | }, |
206 | 215 | { |
207 | 216 | "cell_type": "markdown", |
208 | | - "id": "e1c80c88-b91a-474d-8f66-d830982e4e82", |
| 217 | + "id": "b61aaad7", |
209 | 218 | "metadata": {}, |
210 | 219 | "source": [ |
211 | 220 | "Visualize a generated image chip." |
|
408 | 417 | "metadata": {}, |
409 | 418 | "outputs": [], |
410 | 419 | "source": [ |
411 | | - "outdir_embeddings = \"./embeddings/\"\n", |
| 420 | + "outdir_embeddings = \"../data/embeddings/\"\n", |
412 | 421 | "os.makedirs(outdir_embeddings, exist_ok=True)" |
413 | 422 | ] |
414 | 423 | }, |
|
427 | 436 | "metadata": {}, |
428 | 437 | "outputs": [], |
429 | 438 | "source": [ |
430 | | - "# Load the pretrained model\n", |
| 439 | + "# Download the pretrained model from\n", |
| 440 | + "# https://huggingface.co/made-with-clay/Clay/blob/main/clay-v1-base.ckpt\n", |
| 441 | + "# and put it in a checkpoints folder.\n", |
431 | 442 | "model = load_model(\n", |
432 | | - " # ckpt=\"s3://clay-model-ckpt/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\",\n", |
433 | | - " ckpt=\"../checkpoints/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\",\n", |
434 | | - " device=\"cuda\",\n", |
| 443 | + " ckpt=\"../checkpoints/clay-v1-base.ckpt\",\n", |
| 444 | + " device=torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\"),\n", |
435 | 445 | ")" |
436 | 446 | ] |
437 | 447 | }, |
|
495 | 505 | "metadata": {}, |
496 | 506 | "outputs": [], |
497 | 507 | "source": [ |
498 | | - "embeddings[0].shape" |
499 | | - ] |
500 | | - }, |
501 | | - { |
502 | | - "cell_type": "code", |
503 | | - "execution_count": null, |
504 | | - "id": "1bf9e1fc-b432-4ec3-a028-f85d2ff57469", |
505 | | - "metadata": {}, |
506 | | - "outputs": [], |
507 | | - "source": [ |
508 | | - "len(embeddings)" |
| 508 | + "print(f\"Created {len(embeddings)} embeddings of shape {embeddings[0].shape[1]}\")" |
509 | 509 | ] |
510 | 510 | }, |
511 | 511 | { |
|
525 | 525 | "outputs": [], |
526 | 526 | "source": [ |
527 | 527 | "# Connect to the embeddings database\n", |
528 | | - "db = lancedb.connect(\"embeddings\")" |
| 528 | + "db = lancedb.connect(outdir_embeddings)" |
529 | 529 | ] |
530 | 530 | }, |
531 | 531 | { |
|
583 | 583 | "outputs": [], |
584 | 584 | "source": [ |
585 | 585 | "# Drop existing table if any\n", |
586 | | - "db.drop_table(\"clay-v001\")\n", |
| 586 | + "try:\n", |
| 587 | + " db.drop_table(\"clay-v001\")\n", |
| 588 | + "except FileNotFoundError:\n", |
| 589 | + " pass\n", |
587 | 590 | "db.table_names()" |
588 | 591 | ] |
589 | 592 | }, |
|
618 | 621 | "outputs": [], |
619 | 622 | "source": [ |
620 | 623 | "# Perform the search\n", |
621 | | - "result = tbl.search(query=v).limit(4).to_pandas()" |
| 624 | + "search_x_images = 6\n", |
| 625 | + "result = tbl.search(query=v).limit(search_x_images).to_pandas()" |
622 | 626 | ] |
623 | 627 | }, |
624 | 628 | { |
|
638 | 642 | "metadata": {}, |
639 | 643 | "outputs": [], |
640 | 644 | "source": [ |
641 | | - "def plot(df, cols=4):\n", |
| 645 | + "def plot(df, cols=4, save=False):\n", |
642 | 646 | " \"\"\"\n", |
643 | 647 | " Plot the top similar images.\n", |
644 | 648 | "\n", |
|
656 | 660 | " ax.set_title(f\"{row['idx']}\")\n", |
657 | 661 | " i += 1\n", |
658 | 662 | " plt.tight_layout()\n", |
659 | | - " fig.savefig(\"similar.png\")" |
| 663 | + " if save:\n", |
| 664 | + " fig.savefig(\"similar.png\")" |
660 | 665 | ] |
661 | 666 | }, |
662 | 667 | { |
|
667 | 672 | "outputs": [], |
668 | 673 | "source": [ |
669 | 674 | "# Plot the top similar images\n", |
670 | | - "plot(result)" |
| 675 | + "plot(result, search_x_images)" |
671 | 676 | ] |
672 | 677 | }, |
673 | 678 | { |
|
695 | 700 | "name": "python", |
696 | 701 | "nbconvert_exporter": "python", |
697 | 702 | "pygments_lexer": "ipython3", |
698 | | - "version": "3.9.0" |
| 703 | + "version": "3.11.8" |
699 | 704 | } |
700 | 705 | }, |
701 | 706 | "nbformat": 4, |
|
0 commit comments