Skip to content

Commit 979cffe

Browse files
authored
Sg/revisions for usability (#4)
* updated pip install to toml * updated instructions and scripting * fix to toml * recompute of inference is * naming fix for inference experiment * fix to oom on probabilities
1 parent e729a1d commit 979cffe

File tree

7 files changed

+149
-112
lines changed

7 files changed

+149
-112
lines changed

project/geodata-3d-conditional/inference_demo.ipynb

Lines changed: 78 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,12 @@
140140
{
141141
"data": {
142142
"application/vnd.jupyter.widget-view+json": {
143-
"model_id": "2398bdf49d20427898f2eb0cc16783ef",
143+
"model_id": "6f8261b149204a60a37db193813dc6b6",
144144
"version_major": 2,
145145
"version_minor": 0
146146
},
147147
"text/plain": [
148-
"Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x37992d550_0&reconnect=auto\" class=\"pyvista"
148+
"Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d484a41c1a0_0&reconnect=auto\" class=\"pyvi"
149149
]
150150
},
151151
"metadata": {},
@@ -204,30 +204,45 @@
204204
"source": [
205205
"An auto-populating function is provided that \n",
206206
"\n",
207-
"1. Iterates through a folder `save_dir` containing subfolders `cond_data_folder_title` with conditional data `boreholes.pt` and `true_model.pt`\n",
208-
"2. Creates the conditional data that includes surface, air, and boreholes from `boreholes.pt` and `true_model.pt`\n",
207+
"1. Iterates through a folder `save_dir` containing subfolders `cond_data_folder_title` with paired data `boreholes.pt` and `true_model.pt` containing the boreholes extracted from the ground truth geological model.\n",
208+
"2. Creates the conditional data for the inverse problem that includes surface, air, and boreholes from `boreholes.pt` and `true_model.pt`\n",
209209
"3. Runs the inference routine on the data to produce `n_samples_each` for each set of conditional data\n",
210-
"4. Saves the solutions in the same subfolder with `sample_title_000.pt` naming convention"
210+
"4. Saves the solutions in the same subfolder with `sample_title_000.pt` naming convention\n",
211+
"\n",
212+
"The script below will sample 9 conditional reconstructions for each pair of boreholes with the true model. (The true model is only used to get surface and air data; the subsurface is not used in the inference.)\n",
213+
"\n",
214+
"The sample time is long, so precomputed inference results are available for demonstration of the ensemble analysis below. To run the inference locally, set `USE_PRECOMPUTED_INFERENCE_RESULTS = False` below."
211215
]
212216
},
213217
{
214218
"cell_type": "code",
215219
"execution_count": 7,
220+
"id": "1dc6a8c0",
221+
"metadata": {},
222+
"outputs": [],
223+
"source": [
224+
"USE_PRECOMPUTED_INFERENCE_RESULTS = True"
225+
]
226+
},
227+
{
228+
"cell_type": "code",
229+
"execution_count": 8,
216230
"id": "ba241280",
217231
"metadata": {},
218232
"outputs": [],
219233
"source": [
220234
"from model_inference_experiments import populate_solutions\n",
221235
"\n",
222-
"# populate_solutions(\n",
223-
"# save_dir=save_dir,\n",
224-
"# cond_data_folder_title=cond_data_folder_title,\n",
225-
"# device=device,\n",
226-
"# model=flowmatching_model,\n",
227-
"# n_samples_each=9,\n",
228-
"# batch_size=1,\n",
229-
"# sample_title=\"sample\",\n",
230-
"# )"
236+
"if not USE_PRECOMPUTED_INFERENCE_RESULTS:\n",
237+
" populate_solutions(\n",
238+
" save_dir=save_dir,\n",
239+
" cond_data_folder_title=cond_data_folder_title,\n",
240+
" device=device,\n",
241+
" model=flowmatching_model,\n",
242+
" n_samples_each=9,\n",
243+
" batch_size=1,\n",
244+
" sample_title=\"sample\",\n",
245+
" )"
231246
]
232247
},
233248
{
@@ -240,22 +255,23 @@
240255
},
241256
{
242257
"cell_type": "code",
243-
"execution_count": null,
258+
"execution_count": 9,
244259
"id": "10058c0e",
245260
"metadata": {},
246261
"outputs": [],
247262
"source": [
248263
"from model_inference_experiments import load_solutions, show_solutions\n",
249264
"\n",
250-
"# Same folder as the stored conditional data\n",
251-
"sample_number = 0\n",
252-
"samples_dir = os.path.join(save_dir, f\"{cond_data_folder_title}_{sample_number}\")\n",
253-
"print(\"Loading from:\", samples_dir)\n",
254-
"# Autoparse the true_model.pt, boreholes.pt, and any solutions\n",
255-
"geomodel, boreholes = load_model_and_boreholes(samples_dir)\n",
256-
"solutions = load_solutions(samples_dir, sample_title=\"sample\")\n",
257-
"show_model_and_boreholes(geomodel, boreholes)\n",
258-
"show_solutions(solutions)"
265+
"if not USE_PRECOMPUTED_INFERENCE_RESULTS:\n",
266+
" # Same folder as the stored conditional data\n",
267+
" sample_number = 0\n",
268+
" samples_dir = os.path.join(save_dir, f\"{cond_data_folder_title}_{sample_number}\")\n",
269+
" print(\"Loading from:\", samples_dir)\n",
270+
" # Autoparse the true_model.pt, boreholes.pt, and any solutions\n",
271+
" geomodel, boreholes = load_model_and_boreholes(samples_dir)\n",
272+
" solutions = load_solutions(samples_dir, sample_title=\"sample\")\n",
273+
" show_model_and_boreholes(geomodel, boreholes)\n",
274+
" show_solutions(solutions)"
259275
]
260276
},
261277
{
@@ -269,7 +285,7 @@
269285
},
270286
{
271287
"cell_type": "code",
272-
"execution_count": 9,
288+
"execution_count": 10,
273289
"id": "2301054e",
274290
"metadata": {},
275291
"outputs": [],
@@ -329,15 +345,15 @@
329345
},
330346
{
331347
"cell_type": "code",
332-
"execution_count": 10,
348+
"execution_count": 11,
333349
"id": "145d861c",
334350
"metadata": {},
335351
"outputs": [
336352
{
337353
"name": "stdout",
338354
"output_type": "stream",
339355
"text": [
340-
"Restored to: /Users/sghyseli/Projects/synthgeo-paper/flowtrain_stochastic_interpolation/project/geodata-3d-conditional/samples/jupyter-demo/paper_cond_gen_0\n"
356+
"Restored to: /home/sghys/projects/flowtrain_stochastic_interpolation/project/geodata-3d-conditional/samples/jupyter-demo/paper_cond_gen_0\n"
341357
]
342358
}
343359
],
@@ -360,19 +376,19 @@
360376
},
361377
{
362378
"cell_type": "code",
363-
"execution_count": 11,
379+
"execution_count": 12,
364380
"id": "3afe5e1e",
365381
"metadata": {},
366382
"outputs": [
367383
{
368384
"data": {
369385
"application/vnd.jupyter.widget-view+json": {
370-
"model_id": "efbf0829bf2342e4bec7abe76ac086d8",
386+
"model_id": "0e02af847ae94161aebb114ee640b10b",
371387
"version_major": 2,
372388
"version_minor": 0
373389
},
374390
"text/plain": [
375-
"Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x34fb582f0_1&reconnect=auto\" class=\"pyvista"
391+
"Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d484de739b0_1&reconnect=auto\" class=\"pyvi"
376392
]
377393
},
378394
"metadata": {},
@@ -381,12 +397,12 @@
381397
{
382398
"data": {
383399
"application/vnd.jupyter.widget-view+json": {
384-
"model_id": "86bcac6e346a4d2eae3e1af847d1edd5",
400+
"model_id": "0d1d6198cffd40a78fc1fa18e19e6a53",
385401
"version_major": 2,
386402
"version_minor": 0
387403
},
388404
"text/plain": [
389-
"Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x34def0260_2&reconnect=auto\" class=\"pyvista"
405+
"Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d484e2fa9f0_2&reconnect=auto\" class=\"pyvi"
390406
]
391407
},
392408
"metadata": {},
@@ -395,45 +411,48 @@
395411
],
396412
"source": [
397413
"sample_number = 0\n",
398-
"geomodel, boreholes = load_model_and_boreholes(samples_dir, device=device)\n",
399-
"solutions = load_solutions(samples_dir, sample_title=\"sample\", device=device)\n",
414+
"geomodel, boreholes = load_model_and_boreholes(samples_dir, device=\"cpu\")\n",
415+
"solutions = load_solutions(samples_dir, sample_title=\"sample\", device=\"cpu\")\n",
400416
"show_model_and_boreholes(geomodel, boreholes)\n",
401417
"# Limit to 10 solutions for display\n",
402418
"show_solutions(solutions[0:10])"
403419
]
404420
},
405421
{
406422
"cell_type": "code",
407-
"execution_count": 14,
423+
"execution_count": 15,
408424
"id": "4e45a77a",
409425
"metadata": {},
410426
"outputs": [],
411427
"source": [
412-
"def vote_probabilities(\n",
413-
" solutions: torch.Tensor, num_categories: int = 15\n",
414-
") -> torch.Tensor:\n",
428+
"def vote_probabilities(solutions: torch.Tensor, num_categories: int = 15) -> torch.Tensor:\n",
415429
" \"\"\"\n",
416-
" Compute per-voxel class probabilities by majority vote across the batch.\n",
417-
" Input: [B,X,Y,Z] of categories and Output: [C,X,Y,Z] of probabilities\n",
430+
" Compute per-voxel class probabilities over a batch.\n",
431+
" Input: [B, X, Y, Z] integer categories (may include -1)\n",
432+
" Output: [C, X, Y, Z] float probabilities\n",
418433
" \"\"\"\n",
419434
" assert solutions.dim() == 4\n",
420435
" B, X, Y, Z = solutions.shape\n",
436+
" device = solutions.device\n",
421437
"\n",
422-
" # Shift labels to 0..C-1 if they are -1..C-2\n",
438+
" # Handle negative indices (-1 for \"air\")\n",
423439
" if solutions.min().item() < 0:\n",
424-
" sol_shifted = solutions + 1\n",
425-
" else:\n",
426-
" sol_shifted = solutions\n",
427-
" sol_shifted = sol_shifted.to(torch.long) # required by bincount\n",
440+
" solutions = solutions + 1 # shift to 0..C-1\n",
441+
"\n",
442+
" solutions = solutions.to(torch.long)\n",
428443
"\n",
429-
" sols_one_hot = (\n",
430-
" torch.nn.functional.one_hot(sol_shifted, num_categories)\n",
431-
" .permute(0, 4, 1, 2, 3)\n",
432-
" .float()\n",
433-
" ) # [B, 15, 64, 64, 64]\n",
434-
" probability_vector = sols_one_hot.mean(dim=0, keepdim=False)\n",
444+
" # Accumulator for per-class voxel counts\n",
445+
" accumulator = torch.zeros(num_categories, X, Y, Z, dtype=torch.float32, device=device)\n",
435446
"\n",
436-
" return probability_vector\n",
447+
" # Accumulate one-hot for each sample\n",
448+
" for b in range(B):\n",
449+
" one_hot = torch.nn.functional.one_hot(solutions[b], num_classes=num_categories) # [X, Y, Z, C]\n",
450+
" one_hot = one_hot.permute(3, 0, 1, 2).float() # [C, X, Y, Z]\n",
451+
" accumulator += one_hot\n",
452+
"\n",
453+
" # Normalize by total samples\n",
454+
" probabilities = accumulator / B\n",
455+
" return probabilities\n",
437456
"\n",
438457
"\n",
439458
"solution_probabilistic = vote_probabilities(solutions, num_categories=15)"
@@ -450,19 +469,19 @@
450469
},
451470
{
452471
"cell_type": "code",
453-
"execution_count": 15,
472+
"execution_count": 16,
454473
"id": "7c67d1ec",
455474
"metadata": {},
456475
"outputs": [
457476
{
458477
"data": {
459478
"application/vnd.jupyter.widget-view+json": {
460-
"model_id": "f9877221328e42079086c8e85e0aa00f",
479+
"model_id": "2fbba51d6ac146919348db865e5aad3d",
461480
"version_major": 2,
462481
"version_minor": 0
463482
},
464483
"text/plain": [
465-
"Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x38e7f0260_5&reconnect=auto\" class=\"pyvista"
484+
"Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d483cb153d0_3&reconnect=auto\" class=\"pyvi"
466485
]
467486
},
468487
"metadata": {},
@@ -471,12 +490,12 @@
471490
{
472491
"data": {
473492
"application/vnd.jupyter.widget-view+json": {
474-
"model_id": "eb98d17d3b954dd88d1abb5a16356fe2",
493+
"model_id": "5ba813e5a3b341ab86744048d92adc05",
475494
"version_major": 2,
476495
"version_minor": 0
477496
},
478497
"text/plain": [
479-
"Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x38e831760_6&reconnect=auto\" class=\"pyvista"
498+
"Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d486e937620_4&reconnect=auto\" class=\"pyvi"
480499
]
481500
},
482501
"metadata": {},
@@ -621,7 +640,7 @@
621640
],
622641
"metadata": {
623642
"kernelspec": {
624-
"display_name": "geopaper",
643+
"display_name": "ml",
625644
"language": "python",
626645
"name": "python3"
627646
},
@@ -635,7 +654,7 @@
635654
"name": "python",
636655
"nbconvert_exporter": "python",
637656
"pygments_lexer": "ipython3",
638-
"version": "3.12.11"
657+
"version": "3.11.9"
639658
}
640659
},
641660
"nbformat": 4,

project/geodata-3d-conditional/model_train_sh_inference_cond.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
import argparse
22
import os
33
import platform
4-
import time
5-
import warnings
64
from typing import Any, Dict, List, Tuple, Optional
7-
from functools import partial
85

9-
import matplotlib.pyplot as plt
10-
import numpy as np
11-
import seaborn as sns
126
import torch
13-
import json
147

158
# from cpu_binding import affinity, num_threads
169
# if affinity: # https://github.com/pytorch/pytorch/issues/99625
@@ -21,10 +14,7 @@
2114

2215
import torch.nn as nn
2316
import torch.nn.functional as F
24-
import wandb
25-
from matplotlib import patches
2617
from torch.utils.data import DataLoader
27-
from tqdm import tqdm
2818

2919
# Third-party libraries
3020
from lightning import Trainer

project/geodata-3d-unconditional/model_train_inference.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,14 @@
55
import argparse
66
import os
77
import re
8-
import platform
8+
99
import time
1010
import warnings
1111
from typing import Any, Dict, List, Tuple, Optional, Union
1212

13-
import matplotlib.pyplot as plt
14-
import numpy as np
15-
import seaborn as sns
1613
import torch
1714
import torch.nn as nn
1815
import torch.nn.functional as F
19-
import wandb
20-
from matplotlib import patches
2116
from torch.utils.data import DataLoader
2217
from tqdm import tqdm
2318

@@ -320,7 +315,7 @@ def __init__(
320315
# Embedding layer setup
321316
self.embedding = nn.Embedding(self.num_categories, self.embedding_dim)
322317
self._initialize_embedding(self.num_categories, self.embedding_dim)
323-
# Freeze embedding weights after initialization (non-learnable)
318+
# Freeze embedding weights after initialization (non-learnable hardcoding, set to True for learnable)
324319
self.embedding.weight.requires_grad = False
325320

326321
# Update model_params to reflect the new input channels
@@ -619,6 +614,7 @@ def run_inference(
619614

620615
solver = ODEFlowSolver(model=model.net, rtol=1e-6)
621616

617+
# Start and stop times for ODEFlow, slightly away from t=0 to avoid numerical stability issues
622618
t0, tf = 0.001, 1.0
623619
n_steps = 16
624620

@@ -777,9 +773,10 @@ def parse_arguments():
777773
)
778774

779775
parser.add_argument(
780-
'--save-images',
781-
action='store_true',
782-
help='Save visualization images during inference'
776+
'--save-images',
777+
action=argparse.BooleanOptionalAction,
778+
default=True,
779+
help='Save visualization images during inference (use --no-save-images to disable)'
783780
)
784781

785782
parser.add_argument(
@@ -832,6 +829,10 @@ def main() -> None:
832829
model = Geo3DStochInterp.load_from_checkpoint(
833830
checkpoint_path, map_location=inference_device
834831
).to(inference_device)
832+
833+
print(f"Running inference with {args.n_samples} samples on device {inference_device} with batch size {args.batch_size}")
834+
print(f"Samples will be saved to: {dirs['samples_dir']}")
835+
print(f"Images will be saved to: {dirs['photo_dir'] if args.save_images else 'Not saving images'}")
835836

836837
run_inference(
837838
dirs,

project/geodata-3d-unconditional/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def download_if_missing(path, url):
2525
if not os.path.exists(path):
2626
os.makedirs(os.path.dirname(path), exist_ok=True)
2727
print(f"Downloading weights from {url}...")
28-
urllib.request.urlretrieve(url, path)
28+
torch.hub.download_url_to_file(url, path, progress=True)
2929
print("Download complete.")
3030

3131

0 commit comments

Comments
 (0)