|
140 | 140 | { |
141 | 141 | "data": { |
142 | 142 | "application/vnd.jupyter.widget-view+json": { |
143 | | - "model_id": "2398bdf49d20427898f2eb0cc16783ef", |
| 143 | + "model_id": "6f8261b149204a60a37db193813dc6b6", |
144 | 144 | "version_major": 2, |
145 | 145 | "version_minor": 0 |
146 | 146 | }, |
147 | 147 | "text/plain": [ |
148 | | - "Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x37992d550_0&reconnect=auto\" class=\"pyvista…" |
| 148 | + "Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d484a41c1a0_0&reconnect=auto\" class=\"pyvi…" |
149 | 149 | ] |
150 | 150 | }, |
151 | 151 | "metadata": {}, |
|
204 | 204 | "source": [ |
205 | 205 | "An auto-populating function is provided that \n", |
206 | 206 | "\n", |
207 | | - "1. Iterates through a folder `save_dir` containing subfolders `cond_data_folder_title` with conditional data `boreholes.pt` and `true_model.pt`\n", |
208 | | - "2. Creates the conditional data that includes surface, air, and boreholes from `boreholes.pt` and `true_model.pt`\n", |
| 207 | + "1. Iterates through a folder `save_dir` containing subfolders `cond_data_folder_title` with paired data `boreholes.pt` and `true_model.pt` containing the boreholes extracted from the ground truth geological model.\n", |
| 208 | + "2. Creates the conditional data for the inverse problem that includes surface, air, and boreholes from `boreholes.pt` and `true_model.pt`\n", |
209 | 209 | "3. Runs the inference routine on the data to produce `n_samples_each` for each set of conditional data\n", |
210 | | - "4. Saves the solutions in the same subfolder with `sample_title_000.pt` naming convention" |
| 210 | + "4. Saves the solutions in the same subfolder with `sample_title_000.pt` naming convention\n", |
| 211 | + "\n", |
| 212 | + "The script below will sample 9 conditional reconstructions for each pair of boreholes with true model. (The true model is only used to get surface and air data, subsurface is not used in the inference)\n", |
| 213 | + "\n", |
| 214 | + "The sample time is long, so precomputed inference results available for demonstration of ensemble analysis below. To run the inference locally, set `USE_PRECOMPUTED_INFERENCE_RESULTS = False` below." |
211 | 215 | ] |
212 | 216 | }, |
213 | 217 | { |
214 | 218 | "cell_type": "code", |
215 | 219 | "execution_count": 7, |
| 220 | + "id": "1dc6a8c0", |
| 221 | + "metadata": {}, |
| 222 | + "outputs": [], |
| 223 | + "source": [ |
| 224 | + "USE_PRECOMPUTED_INFERENCE_RESULTS = True" |
| 225 | + ] |
| 226 | + }, |
| 227 | + { |
| 228 | + "cell_type": "code", |
| 229 | + "execution_count": 8, |
216 | 230 | "id": "ba241280", |
217 | 231 | "metadata": {}, |
218 | 232 | "outputs": [], |
219 | 233 | "source": [ |
220 | 234 | "from model_inference_experiments import populate_solutions\n", |
221 | 235 | "\n", |
222 | | - "# populate_solutions(\n", |
223 | | - "# save_dir=save_dir,\n", |
224 | | - "# cond_data_folder_title=cond_data_folder_title,\n", |
225 | | - "# device=device,\n", |
226 | | - "# model=flowmatching_model,\n", |
227 | | - "# n_samples_each=9,\n", |
228 | | - "# batch_size=1,\n", |
229 | | - "# sample_title=\"sample\",\n", |
230 | | - "# )" |
| 236 | + "if not USE_PRECOMPUTED_INFERENCE_RESULTS:\n", |
| 237 | + " populate_solutions(\n", |
| 238 | + " save_dir=save_dir,\n", |
| 239 | + " cond_data_folder_title=cond_data_folder_title,\n", |
| 240 | + " device=device,\n", |
| 241 | + " model=flowmatching_model,\n", |
| 242 | + " n_samples_each=9,\n", |
| 243 | + " batch_size=1,\n", |
| 244 | + " sample_title=\"sample\",\n", |
| 245 | + " )" |
231 | 246 | ] |
232 | 247 | }, |
233 | 248 | { |
|
240 | 255 | }, |
241 | 256 | { |
242 | 257 | "cell_type": "code", |
243 | | - "execution_count": null, |
| 258 | + "execution_count": 9, |
244 | 259 | "id": "10058c0e", |
245 | 260 | "metadata": {}, |
246 | 261 | "outputs": [], |
247 | 262 | "source": [ |
248 | 263 | "from model_inference_experiments import load_solutions, show_solutions\n", |
249 | 264 | "\n", |
250 | | - "# Same folder as the stored conditional data\n", |
251 | | - "sample_number = 0\n", |
252 | | - "samples_dir = os.path.join(save_dir, f\"{cond_data_folder_title}_{sample_number}\")\n", |
253 | | - "print(\"Loading from:\", samples_dir)\n", |
254 | | - "# Autoparse the true_model.pt, boreholes.pt, and any solutions\n", |
255 | | - "geomodel, boreholes = load_model_and_boreholes(samples_dir)\n", |
256 | | - "solutions = load_solutions(samples_dir, sample_title=\"sample\")\n", |
257 | | - "show_model_and_boreholes(geomodel, boreholes)\n", |
258 | | - "show_solutions(solutions)" |
| 265 | + "if not USE_PRECOMPUTED_INFERENCE_RESULTS:\n", |
| 266 | + " # Same folder as the stored conditional data\n", |
| 267 | + " sample_number = 0\n", |
| 268 | + " samples_dir = os.path.join(save_dir, f\"{cond_data_folder_title}_{sample_number}\")\n", |
| 269 | + " print(\"Loading from:\", samples_dir)\n", |
| 270 | + " # Autoparse the true_model.pt, boreholes.pt, and any solutions\n", |
| 271 | + " geomodel, boreholes = load_model_and_boreholes(samples_dir)\n", |
| 272 | + " solutions = load_solutions(samples_dir, sample_title=\"sample\")\n", |
| 273 | + " show_model_and_boreholes(geomodel, boreholes)\n", |
| 274 | + " show_solutions(solutions)" |
259 | 275 | ] |
260 | 276 | }, |
261 | 277 | { |
|
269 | 285 | }, |
270 | 286 | { |
271 | 287 | "cell_type": "code", |
272 | | - "execution_count": 9, |
| 288 | + "execution_count": 10, |
273 | 289 | "id": "2301054e", |
274 | 290 | "metadata": {}, |
275 | 291 | "outputs": [], |
|
329 | 345 | }, |
330 | 346 | { |
331 | 347 | "cell_type": "code", |
332 | | - "execution_count": 10, |
| 348 | + "execution_count": 11, |
333 | 349 | "id": "145d861c", |
334 | 350 | "metadata": {}, |
335 | 351 | "outputs": [ |
336 | 352 | { |
337 | 353 | "name": "stdout", |
338 | 354 | "output_type": "stream", |
339 | 355 | "text": [ |
340 | | - "Restored to: /Users/sghyseli/Projects/synthgeo-paper/flowtrain_stochastic_interpolation/project/geodata-3d-conditional/samples/jupyter-demo/paper_cond_gen_0\n" |
| 356 | + "Restored to: /home/sghys/projects/flowtrain_stochastic_interpolation/project/geodata-3d-conditional/samples/jupyter-demo/paper_cond_gen_0\n" |
341 | 357 | ] |
342 | 358 | } |
343 | 359 | ], |
|
360 | 376 | }, |
361 | 377 | { |
362 | 378 | "cell_type": "code", |
363 | | - "execution_count": 11, |
| 379 | + "execution_count": 12, |
364 | 380 | "id": "3afe5e1e", |
365 | 381 | "metadata": {}, |
366 | 382 | "outputs": [ |
367 | 383 | { |
368 | 384 | "data": { |
369 | 385 | "application/vnd.jupyter.widget-view+json": { |
370 | | - "model_id": "efbf0829bf2342e4bec7abe76ac086d8", |
| 386 | + "model_id": "0e02af847ae94161aebb114ee640b10b", |
371 | 387 | "version_major": 2, |
372 | 388 | "version_minor": 0 |
373 | 389 | }, |
374 | 390 | "text/plain": [ |
375 | | - "Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x34fb582f0_1&reconnect=auto\" class=\"pyvista…" |
| 391 | + "Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d484de739b0_1&reconnect=auto\" class=\"pyvi…" |
376 | 392 | ] |
377 | 393 | }, |
378 | 394 | "metadata": {}, |
|
381 | 397 | { |
382 | 398 | "data": { |
383 | 399 | "application/vnd.jupyter.widget-view+json": { |
384 | | - "model_id": "86bcac6e346a4d2eae3e1af847d1edd5", |
| 400 | + "model_id": "0d1d6198cffd40a78fc1fa18e19e6a53", |
385 | 401 | "version_major": 2, |
386 | 402 | "version_minor": 0 |
387 | 403 | }, |
388 | 404 | "text/plain": [ |
389 | | - "Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x34def0260_2&reconnect=auto\" class=\"pyvista…" |
| 405 | + "Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d484e2fa9f0_2&reconnect=auto\" class=\"pyvi…" |
390 | 406 | ] |
391 | 407 | }, |
392 | 408 | "metadata": {}, |
|
395 | 411 | ], |
396 | 412 | "source": [ |
397 | 413 | "sample_number = 0\n", |
398 | | - "geomodel, boreholes = load_model_and_boreholes(samples_dir, device=device)\n", |
399 | | - "solutions = load_solutions(samples_dir, sample_title=\"sample\", device=device)\n", |
| 414 | + "geomodel, boreholes = load_model_and_boreholes(samples_dir, device=\"cpu\")\n", |
| 415 | + "solutions = load_solutions(samples_dir, sample_title=\"sample\", device=\"cpu\")\n", |
400 | 416 | "show_model_and_boreholes(geomodel, boreholes)\n", |
401 | 417 | "# Limit to 10 solutions for display\n", |
402 | 418 | "show_solutions(solutions[0:10])" |
403 | 419 | ] |
404 | 420 | }, |
405 | 421 | { |
406 | 422 | "cell_type": "code", |
407 | | - "execution_count": 14, |
| 423 | + "execution_count": 15, |
408 | 424 | "id": "4e45a77a", |
409 | 425 | "metadata": {}, |
410 | 426 | "outputs": [], |
411 | 427 | "source": [ |
412 | | - "def vote_probabilities(\n", |
413 | | - " solutions: torch.Tensor, num_categories: int = 15\n", |
414 | | - ") -> torch.Tensor:\n", |
| 428 | + "def vote_probabilities(solutions: torch.Tensor, num_categories: int = 15) -> torch.Tensor:\n", |
415 | 429 | " \"\"\"\n", |
416 | | - " Compute per-voxel class probabilities by majority vote across the batch.\n", |
417 | | - " Input: [B,X,Y,Z] of categories and Output: [C,X,Y,Z] of probabilities\n", |
| 430 | + " Compute per-voxel class probabilities over a batch.\n", |
| 431 | + " Input: [B, X, Y, Z] integer categories (may include -1)\n", |
| 432 | + " Output: [C, X, Y, Z] float probabilities\n", |
418 | 433 | " \"\"\"\n", |
419 | 434 | " assert solutions.dim() == 4\n", |
420 | 435 | " B, X, Y, Z = solutions.shape\n", |
| 436 | + " device = solutions.device\n", |
421 | 437 | "\n", |
422 | | - " # Shift labels to 0..C-1 if they are -1..C-2\n", |
| 438 | + " # Handle negative indices (-1 for \"air\")\n", |
423 | 439 | " if solutions.min().item() < 0:\n", |
424 | | - " sol_shifted = solutions + 1\n", |
425 | | - " else:\n", |
426 | | - " sol_shifted = solutions\n", |
427 | | - " sol_shifted = sol_shifted.to(torch.long) # required by bincount\n", |
| 440 | + " solutions = solutions + 1 # shift to 0..C-1\n", |
| 441 | + "\n", |
| 442 | + " solutions = solutions.to(torch.long)\n", |
428 | 443 | "\n", |
429 | | - " sols_one_hot = (\n", |
430 | | - " torch.nn.functional.one_hot(sol_shifted, num_categories)\n", |
431 | | - " .permute(0, 4, 1, 2, 3)\n", |
432 | | - " .float()\n", |
433 | | - " ) # [B, 15, 64, 64, 64]\n", |
434 | | - " probability_vector = sols_one_hot.mean(dim=0, keepdim=False)\n", |
| 444 | + " # Accumulator for per-class voxel counts\n", |
| 445 | + " accumulator = torch.zeros(num_categories, X, Y, Z, dtype=torch.float32, device=device)\n", |
435 | 446 | "\n", |
436 | | - " return probability_vector\n", |
| 447 | + " # Accumulate one-hot for each sample\n", |
| 448 | + " for b in range(B):\n", |
| 449 | + " one_hot = torch.nn.functional.one_hot(solutions[b], num_classes=num_categories) # [X, Y, Z, C]\n", |
| 450 | + " one_hot = one_hot.permute(3, 0, 1, 2).float() # [C, X, Y, Z]\n", |
| 451 | + " accumulator += one_hot\n", |
| 452 | + "\n", |
| 453 | + " # Normalize by total samples\n", |
| 454 | + " probabilities = accumulator / B\n", |
| 455 | + " return probabilities\n", |
437 | 456 | "\n", |
438 | 457 | "\n", |
439 | 458 | "solution_probabilistic = vote_probabilities(solutions, num_categories=15)" |
|
450 | 469 | }, |
451 | 470 | { |
452 | 471 | "cell_type": "code", |
453 | | - "execution_count": 15, |
| 472 | + "execution_count": 16, |
454 | 473 | "id": "7c67d1ec", |
455 | 474 | "metadata": {}, |
456 | 475 | "outputs": [ |
457 | 476 | { |
458 | 477 | "data": { |
459 | 478 | "application/vnd.jupyter.widget-view+json": { |
460 | | - "model_id": "f9877221328e42079086c8e85e0aa00f", |
| 479 | + "model_id": "2fbba51d6ac146919348db865e5aad3d", |
461 | 480 | "version_major": 2, |
462 | 481 | "version_minor": 0 |
463 | 482 | }, |
464 | 483 | "text/plain": [ |
465 | | - "Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x38e7f0260_5&reconnect=auto\" class=\"pyvista…" |
| 484 | + "Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d483cb153d0_3&reconnect=auto\" class=\"pyvi…" |
466 | 485 | ] |
467 | 486 | }, |
468 | 487 | "metadata": {}, |
|
471 | 490 | { |
472 | 491 | "data": { |
473 | 492 | "application/vnd.jupyter.widget-view+json": { |
474 | | - "model_id": "eb98d17d3b954dd88d1abb5a16356fe2", |
| 493 | + "model_id": "5ba813e5a3b341ab86744048d92adc05", |
475 | 494 | "version_major": 2, |
476 | 495 | "version_minor": 0 |
477 | 496 | }, |
478 | 497 | "text/plain": [ |
479 | | - "Widget(value='<iframe src=\"http://localhost:53764/index.html?ui=P_0x38e831760_6&reconnect=auto\" class=\"pyvista…" |
| 498 | + "Widget(value='<iframe src=\"http://localhost:36623/index.html?ui=P_0x7d486e937620_4&reconnect=auto\" class=\"pyvi…" |
480 | 499 | ] |
481 | 500 | }, |
482 | 501 | "metadata": {}, |
|
621 | 640 | ], |
622 | 641 | "metadata": { |
623 | 642 | "kernelspec": { |
624 | | - "display_name": "geopaper", |
| 643 | + "display_name": "ml", |
625 | 644 | "language": "python", |
626 | 645 | "name": "python3" |
627 | 646 | }, |
|
635 | 654 | "name": "python", |
636 | 655 | "nbconvert_exporter": "python", |
637 | 656 | "pygments_lexer": "ipython3", |
638 | | - "version": "3.12.11" |
| 657 | + "version": "3.11.9" |
639 | 658 | } |
640 | 659 | }, |
641 | 660 | "nbformat": 4, |
|
0 commit comments