Commit 0873539

Review suggestions for PR #247
1 parent f655c8c commit 0873539

1 file changed: +48 -43 lines changed

nbs/v1-inference-simsearch-naip-stacchip.ipynb

Lines changed: 48 additions & 43 deletions
@@ -36,8 +36,17 @@
    "source": [
     "import sys\n",
     "\n",
-    "sys.path.append(\"../model\")\n",
-    "sys.path.insert(0, \"../stacchip\")"
+    "sys.path.append(\"..\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eabd5bef",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install stacchip==0.1.31"
    ]
   },
   {
@@ -59,13 +68,12 @@
     "import numpy as np\n",
     "import pandas as pd\n",
     "import pystac_client\n",
-    "import requests\n",
+    "import rasterio\n",
     "import shapely\n",
     "import torch\n",
     "import yaml\n",
     "from box import Box\n",
     "from pyproj import Transformer\n",
-    "from rasterio.io import MemoryFile\n",
     "from stacchip.chipper import Chipper\n",
     "from stacchip.indexer import NoStatsChipIndexer\n",
     "from stacchip.processors.prechip import normalize_timestamp\n",
@@ -145,20 +153,19 @@
     "    Returns:\n",
     "        tuple: Bounds coordinates and centroid coordinates.\n",
     "    \"\"\"\n",
-    "    response = requests.get(url)\n",
-    "    response.raise_for_status()\n",
+    "    with rasterio.open(url) as rst:\n",
+    "        bounds = rst.bounds\n",
+    "        transformer = Transformer.from_crs(rst.crs, 4326)\n",
+    "\n",
+    "    centroid_x = (bounds.left + bounds.right) / 2\n",
+    "    centroid_y = (bounds.top + bounds.bottom) / 2\n",
+    "\n",
+    "    centroid_x, centroid_y = transformer.transform(centroid_x, centroid_y)\n",
     "\n",
-    "    with MemoryFile(response.content) as memfile:\n",
-    "        with memfile.open() as src:\n",
-    "            bounds = src.bounds\n",
-    "            transformer = Transformer.from_crs(src.crs, 4326)\n",
-    "            # Calculate centroid\n",
-    "            centroid_x = (bounds.left + bounds.right) / 2\n",
-    "            centroid_y = (bounds.top + bounds.bottom) / 2\n",
-    "            centroid_x, centroid_y = transformer.transform(centroid_x, centroid_y)\n",
-    "            bounds_b, bounds_l = transformer.transform(bounds.left, bounds.bottom)\n",
-    "            bounds_t, bounds_r = transformer.transform(bounds.right, bounds.top)\n",
-    "            return [bounds_b, bounds_l, bounds_t, bounds_r], centroid_x, centroid_y"
+    "    bounds_b, bounds_l = transformer.transform(bounds.left, bounds.bottom)\n",
+    "    bounds_t, bounds_r = transformer.transform(bounds.right, bounds.top)\n",
+    "\n",
+    "    return [bounds_b, bounds_l, bounds_t, bounds_r], [centroid_x, centroid_y]"
    ]
   },
   {
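
Note on the hunk above: opening the asset URL directly with rasterio means only the GeoTIFF header is range-read over HTTP, so no pixel data is downloaded just to get the bounds. A minimal standalone sketch of the same pattern, assuming a GDAL-readable URL; the helper name and `naip_url` are placeholders, and `always_xy=True` is added here only to make the (lon, lat) ordering explicit:

    import rasterio
    from pyproj import Transformer

    def raster_bounds_and_centroid(naip_url):
        # Open the remote raster; only metadata is read, not the pixels.
        with rasterio.open(naip_url) as rst:
            bounds = rst.bounds
            # always_xy=True keeps inputs/outputs in (x, y) = (lon, lat) order.
            transformer = Transformer.from_crs(rst.crs, 4326, always_xy=True)

        centroid_lon, centroid_lat = transformer.transform(
            (bounds.left + bounds.right) / 2, (bounds.top + bounds.bottom) / 2
        )
        left, bottom = transformer.transform(bounds.left, bounds.bottom)
        right, top = transformer.transform(bounds.right, bounds.top)
        return [left, bottom, right, top], [centroid_lon, centroid_lat]
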
@@ -192,11 +199,13 @@
     "    print(f\"Bounds coordinates: {bounds}, centroid coordinates: {centroid}\")\n",
     "\n",
     "    # Instantiate the chipper\n",
-    "    chipper = Chipper(indexer, asset_blacklist=[\"metadata\"])\n",
+    "    chipper = Chipper(\n",
+    "        indexer, asset_blacklist=[\"thumbnail\", \"tilejson\", \"rendered_preview\"]\n",
+    "    )\n",
     "\n",
     "    # Get 5 randomly sampled chips from the total\n",
     "    # number of chips within this item's entire image\n",
-    "    for chip_id in random.sample(range(0, len(chipper)), 5):\n",
+    "    for chip_id in random.sample(range(0, len(chipper)), 25):\n",
     "        chips.append(chipper[chip_id])\n",
     "        chip_images.append(chipper[chip_id][\"image\"])\n",
     "        chip_bounds.append(bounds)\n",
@@ -205,7 +214,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "e1c80c88-b91a-474d-8f66-d830982e4e82",
+   "id": "b61aaad7",
    "metadata": {},
    "source": [
     "Visualize a generated image chip."
@@ -408,7 +417,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "outdir_embeddings = \"./embeddings/\"\n",
+    "outdir_embeddings = \"../data/embeddings/\"\n",
     "os.makedirs(outdir_embeddings, exist_ok=True)"
    ]
   },
@@ -427,11 +436,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Load the pretrained model\n",
+    "# Download the pretrained model from\n",
+    "# https://huggingface.co/made-with-clay/Clay/blob/main/clay-v1-base.ckpt\n",
+    "# and put it in a checkpoints folder.\n",
     "model = load_model(\n",
-    "    # ckpt=\"s3://clay-model-ckpt/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\",\n",
-    "    ckpt=\"../checkpoints/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\",\n",
-    "    device=\"cuda\",\n",
+    "    ckpt=\"../checkpoints/clay-v1-base.ckpt\",\n",
+    "    device=torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\"),\n",
     ")"
    ]
   },
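
If the checkpoint should be fetched programmatically instead of by hand, a cell like the sketch below could accompany this one. It assumes the optional huggingface_hub package; the repo and file names come from the comment in the hunk above:

    from huggingface_hub import hf_hub_download

    # Fetch clay-v1-base.ckpt from the made-with-clay/Clay repo into ../checkpoints/.
    ckpt_path = hf_hub_download(
        repo_id="made-with-clay/Clay",
        filename="clay-v1-base.ckpt",
        local_dir="../checkpoints",
    )
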
@@ -495,17 +505,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "embeddings[0].shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "1bf9e1fc-b432-4ec3-a028-f85d2ff57469",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "len(embeddings)"
+    "print(f\"Created {len(embeddings)} embeddings of shape {embeddings[0].shape[1]}\")"
    ]
   },
   {
@@ -525,7 +525,7 @@
    "outputs": [],
    "source": [
     "# Connect to the embeddings database\n",
-    "db = lancedb.connect(\"embeddings\")"
+    "db = lancedb.connect(outdir_embeddings)"
    ]
   },
   {
@@ -583,7 +583,10 @@
    "outputs": [],
    "source": [
     "# Drop existing table if any\n",
-    "db.drop_table(\"clay-v001\")\n",
+    "try:\n",
+    "    db.drop_table(\"clay-v001\")\n",
+    "except FileNotFoundError:\n",
+    "    pass\n",
     "db.table_names()"
    ]
   },
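
The try/except reflects that dropping a table that was never created raises FileNotFoundError, as the hunk above assumes. The same guard, wrapped as a reusable helper; `reset_table` is a hypothetical name and uses only the lancedb calls already present in the notebook:

    def reset_table(db, name="clay-v001"):
        # Drop the table if it exists; a missing table is already the desired state.
        try:
            db.drop_table(name)
        except FileNotFoundError:
            pass
        return db.table_names()
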
@@ -618,7 +621,8 @@
    "outputs": [],
    "source": [
     "# Perform the search\n",
-    "result = tbl.search(query=v).limit(4).to_pandas()"
+    "search_x_images = 6\n",
+    "result = tbl.search(query=v).limit(search_x_images).to_pandas()"
    ]
   },
   {
@@ -638,7 +642,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def plot(df, cols=4):\n",
+    "def plot(df, cols=4, save=False):\n",
     "    \"\"\"\n",
     "    Plot the top similar images.\n",
     "\n",
@@ -656,7 +660,8 @@
     "        ax.set_title(f\"{row['idx']}\")\n",
     "        i += 1\n",
     "    plt.tight_layout()\n",
-    "    fig.savefig(\"similar.png\")"
+    "    if save:\n",
+    "        fig.savefig(\"similar.png\")"
    ]
   },
   {
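
Making the savefig call opt-in keeps reruns from silently overwriting similar.png. A self-contained sketch of the same idea, with illustrative names rather than the notebook's exact plot() signature:

    import matplotlib.pyplot as plt

    def plot_grid(images, cols=4, save=False, path="similar.png"):
        # Lay the images out on a grid and only write a file when asked to.
        rows = -(-len(images) // cols)  # ceiling division
        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
        for ax in axes.flat:
            ax.set_axis_off()
        for ax, img in zip(axes.flat, images):
            ax.imshow(img)
        plt.tight_layout()
        if save:
            fig.savefig(path)
        plt.show()
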
@@ -667,7 +672,7 @@
    "outputs": [],
    "source": [
     "# Plot the top similar images\n",
-    "plot(result)"
+    "plot(result, search_x_images)"
    ]
   },
   {
@@ -695,7 +700,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.0"
+   "version": "3.11.8"
   }
  },
  "nbformat": 4,
