|
43 | 43 | "source": [
|
44 | 44 | "from matplotlib import pyplot as plt\n",
|
45 | 45 | "\n",
|
46 |
| - "plt.imshow(f[0].data, origin='lower', cmap='gray_r')\n", |
| 46 | + "plt.imshow(f[0].data, origin='lower', cmap='Greys_r')\n", |
47 | 47 | "print(\"Behold, the M2 globular cluster!\")"
|
48 | 48 | ]
|
49 | 49 | },
|
|
54 | 54 | "outputs": [],
|
55 | 55 | "source": [
|
56 | 56 | "logimage = np.log(f[0].data - f[0].data.min() + 1)\n",
|
57 |
| - "_ = plt.imshow(logimage, origin='lower', cmap='gray_r')" |
| 57 | + "plt.imshow(logimage, origin='lower', cmap='Greys_r');" |
| 58 | + ] |
| 59 | + }, |
| 60 | + { |
| 61 | + "cell_type": "code", |
| 62 | + "execution_count": null, |
| 63 | + "metadata": {}, |
| 64 | + "outputs": [], |
| 65 | + "source": [ |
| 66 | + "from matplotlib.patches import Rectangle\n", |
| 67 | + "\n", |
| 68 | + "plt.imshow(logimage, origin='lower', cmap='Greys_r')\n", |
| 69 | + "rect = Rectangle((310, 630), 100, 100, linewidth=2, edgecolor='r', facecolor='none')\n", |
| 70 | + "_ = plt.gca().add_patch(rect)\n", |
| 71 | + "plt.xticks([])\n", |
| 72 | + "plt.yticks([]);" |
58 | 73 | ]
|
59 | 74 | },
|
60 | 75 | {
|
|
90 | 105 | "metadata": {},
|
91 | 106 | "outputs": [],
|
92 | 107 | "source": [
|
93 |
| - "from matplotlib.patches import Rectangle\n", |
| 108 | + "original = f[0].data[630:730, 310:410]\n", |
| 109 | + "\n", |
| 110 | + "arcsinh_median = np.arcsinh((original - np.median(original)))\n", |
94 | 111 | "\n",
|
95 |
| - "plt.imshow(logimage, origin='lower', cmap='gray_r')\n", |
96 |
| - "plt.scatter(plocs_all[:, 1], plocs_all[:, 0], s=10, c='r')\n", |
97 |
| - "rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')\n", |
98 |
| - "plt.gca().add_patch(rect)" |
| 112 | + "clipped = original.clip(max=np.quantile(original, 0.98))\n", |
| 113 | + "arcsinh_clipped = np.arcsinh((clipped - np.median(clipped)));" |
99 | 114 | ]
|
100 | 115 | },
|
101 | 116 | {
|
|
104 | 119 | "metadata": {},
|
105 | 120 | "outputs": [],
|
106 | 121 | "source": [
|
107 |
| - "in_bounds = (plocs_all[:, 1] > 310) & (plocs_all[:, 1] < 410)\n", |
108 |
| - "in_bounds &= (plocs_all[:, 0] > 630) & (plocs_all[:, 0] < 730)\n", |
109 |
| - "in_bounds.sum()" |
| 122 | + "\n", |
| 123 | + "\n", |
| 124 | + "fig, axs = plt.subplots(1, 3, figsize=(10, 10))\n", |
| 125 | + "\n", |
| 126 | + "images = [original, arcsinh_median, arcsinh_clipped]\n", |
| 127 | + "titles = ['original', 'arcsinh', 'arcsinh with clipping']\n", |
| 128 | + "\n", |
| 129 | + "for i, img in enumerate(images):\n", |
| 130 | + " ax = axs[i]\n", |
| 131 | + " ax.imshow(img, origin='lower', cmap='Greys_r')\n", |
| 132 | + " ax.set_title(titles[i])\n", |
| 133 | + " ax.set_xticks([])\n", |
| 134 | + " ax.set_yticks([])\n", |
| 135 | + "\n", |
| 136 | + "plt.tight_layout()\n", |
| 137 | + "plt.show()" |
110 | 138 | ]
|
111 | 139 | },
|
112 | 140 | {
|
|
115 | 143 | "metadata": {},
|
116 | 144 | "outputs": [],
|
117 | 145 | "source": [
|
118 |
| - "plt.imshow(logimage, origin='lower', cmap='gray_r')\n", |
119 |
| - "plt.scatter(plocs_all[:, 1][in_bounds], plocs_all[:, 0][in_bounds], s=10, c='r')\n", |
120 |
| - "rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')\n", |
121 |
| - "_ = plt.gca().add_patch(rect)" |
| 146 | + "in_bounds = (plocs_all[:, 1] > 310) & (plocs_all[:, 1] < 410)\n", |
| 147 | + "in_bounds &= (plocs_all[:, 0] > 630) & (plocs_all[:, 0] < 730)\n", |
| 148 | + "in_bounds.sum()" |
122 | 149 | ]
|
123 | 150 | },
|
124 | 151 | {
|
|
220 | 247 | "true_tile_cat.n_sources.sum()"
|
221 | 248 | ]
|
222 | 249 | },
|
| 250 | + { |
| 251 | + "cell_type": "code", |
| 252 | + "execution_count": null, |
| 253 | + "metadata": {}, |
| 254 | + "outputs": [], |
| 255 | + "source": [ |
| 256 | + "fig, axs = plt.subplots(1, 3, figsize=(10, 10))\n", |
| 257 | + "\n", |
| 258 | + "cutoffs = [20, 22.065, 24]\n", |
| 259 | + "\n", |
| 260 | + "for i, cutoff in enumerate(cutoffs):\n", |
| 261 | + " is_bright = sdss_r_mag < cutoff\n", |
| 262 | + " plocs_square_bright = plocs_square[is_bright]\n", |
| 263 | + " ax = axs[i]\n", |
| 264 | + " ax.imshow(arcsinh_clipped, origin='lower', cmap='Greys_r')\n", |
| 265 | + " ax.scatter(plocs_square_bright[:, 1], plocs_square_bright[:, 0], s=5, c='r')\n", |
| 266 | + " ax.set_title(f\"magnitude < {cutoff}\")\n", |
| 267 | + " ax.set_xlim(0, 100)\n", |
| 268 | + " ax.set_ylim(0, 100)\n", |
| 269 | + " ax.set_xticks([])\n", |
| 270 | + " ax.set_yticks([])\n", |
| 271 | + "\n", |
| 272 | + "plt.tight_layout()\n", |
| 273 | + "plt.show()\n" |
| 274 | + ] |
| 275 | + }, |
223 | 276 | {
|
224 | 277 | "cell_type": "markdown",
|
225 | 278 | "metadata": {},
|
|
244 | 297 | "with initialize(config_path=\"../../case_studies/dependent_tiling/\", version_base=None):\n",
|
245 | 298 | " cfg = compose(\"m2_config\", {\n",
|
246 | 299 | " \"encoder.tiles_to_crop=3\",\n",
|
247 |
| - " \"predict.weight_save_path=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n", |
| 300 | + " \"predict.weight_save_path=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n", |
248 | 301 | " # \"encoder.double_detect=false\"\n",
|
249 | 302 | " })"
|
250 | 303 | ]
|
|
290 | 343 | "metadata": {},
|
291 | 344 | "outputs": [],
|
292 | 345 | "source": [
|
| 346 | + "starnet = {\n", |
| 347 | + " \"recall\": [0.95, 0.91, 0.79, 0.7, 0.7, 0.62, 0.59, 0.4],\n", |
| 348 | + " \"precision\": [0.96, 0.97, 0.79, 0.8, 0.68, 0.6, 0.45, 0.35]\n", |
| 349 | + "}\n", |
| 350 | + "\n", |
| 351 | + "starnet[\"f1\"] = 2 * np.array(starnet[\"recall\"]) * np.array(starnet[\"precision\"])\n", |
| 352 | + "starnet[\"f1\"] /= (np.array(starnet[\"recall\"]) + np.array(starnet[\"precision\"]))\n", |
| 353 | + "\n", |
293 | 354 | "for name, metric in metrics.items():\n",
|
294 |
| - " metric.plot()" |
| 355 | + " metric.plot()\n" |
| 356 | + ] |
| 357 | + }, |
| 358 | + { |
| 359 | + "cell_type": "markdown", |
| 360 | + "metadata": {}, |
| 361 | + "source": [ |
| 362 | + "Check calibration:" |
| 363 | + ] |
| 364 | + }, |
| 365 | + { |
| 366 | + "cell_type": "code", |
| 367 | + "execution_count": null, |
| 368 | + "metadata": {}, |
| 369 | + "outputs": [], |
| 370 | + "source": [ |
| 371 | + "%%capture\n", |
| 372 | + "counts = []\n", |
| 373 | + "\n", |
| 374 | + "for i in range(15):\n", |
| 375 | + " bliss_cats = predict(cfg.predict)\n", |
| 376 | + " bliss_cat_pair, = bliss_cats.values()\n", |
| 377 | + " bliss_cat = bliss_cat_pair[\"sample_cat\"].to_full_catalog()\n", |
| 378 | + " counts.append(bliss_cat.n_sources.sum())\n", |
| 379 | + "\n", |
| 380 | + "counts" |
| 381 | + ] |
| 382 | + }, |
| 383 | + { |
| 384 | + "cell_type": "code", |
| 385 | + "execution_count": null, |
| 386 | + "metadata": {}, |
| 387 | + "outputs": [], |
| 388 | + "source": [ |
| 389 | + "cs = torch.tensor([c.item() for c in counts]).float()\n", |
| 390 | + "cs.mean(), cs.quantile(0.05), cs.quantile(0.95)" |
| 391 | + ] |
| 392 | + }, |
| 393 | + { |
| 394 | + "cell_type": "markdown", |
| 395 | + "metadata": {}, |
| 396 | + "source": [ |
| 397 | + "### Independent tiling (baseline)" |
295 | 398 | ]
|
296 | 399 | },
|
297 | 400 | {
|
|
316 | 419 | "bliss_cat_marginal = bliss_cat_pair[\"mode_cat\"].to_full_catalog()\n",
|
317 | 420 | "matching = matcher.match_catalogs(true_cat, bliss_cat_marginal)\n",
|
318 | 421 | "metric = metrics(true_cat, bliss_cat_marginal, matching)\n",
|
319 |
| - "for name, m in metrics.items():\n", |
320 |
| - " m.plot()\n", |
| 422 | + "\n", |
| 423 | + "m = metrics[\"DetectionPerformance\"]\n", |
| 424 | + "m.plot()\n", |
321 | 425 | "\n",
|
322 | 426 | "metric[\"detection_recall\"], metric[\"detection_precision\"], metric[\"detection_f1\"]"
|
323 | 427 | ]
|
324 | 428 | },
|
| 429 | + { |
| 430 | + "cell_type": "code", |
| 431 | + "execution_count": null, |
| 432 | + "metadata": {}, |
| 433 | + "outputs": [], |
| 434 | + "source": [ |
| 435 | + "recall = m.n_true_matches / m.n_true_sources\n", |
| 436 | + "precision = m.n_est_matches / m.n_est_sources\n", |
| 437 | + "f1 = 2 * precision * recall / (precision + recall)\n", |
| 438 | + "real = {\"recall\": recall, \"precision\": precision, \"f1\": f1}" |
| 439 | + ] |
| 440 | + }, |
325 | 441 | {
|
326 | 442 | "cell_type": "markdown",
|
327 | 443 | "metadata": {},
|
|
339 | 455 | " cfg3 = compose(\"m2_config\", {\n",
|
340 | 456 | " \"train.trainer.logger=null\",\n",
|
341 | 457 | " \"train.trainer.max_epochs=0\",\n",
|
342 |
| - " \"train.pretrained_weights=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n", |
| 458 | + " \"train.pretrained_weights=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n", |
343 | 459 | " \"cached_simulator.cached_data_path=/data/scratch/regier/toy_m2\",\n",
|
344 | 460 | " \"+train.trainer.num_sanity_val_steps=0\",\n",
|
345 | 461 | "# \"encoder.double_detect=false\"\n",
|
|
388 | 504 | "outputs": [],
|
389 | 505 | "source": [
|
390 | 506 | "obs_image = torch.from_numpy(dataset[0][\"image\"][2][6:-6, 6:-6])\n",
|
391 |
| - "plt.imshow(obs_image)\n", |
| 507 | + "plt.imshow(obs_image, origin='lower', cmap='Greys_r')\n", |
392 | 508 | "_ = plt.colorbar()"
|
393 | 509 | ]
|
394 | 510 | },
|
|
409 | 525 | "outputs": [],
|
410 | 526 | "source": [
|
411 | 527 | "true_recon_all = truth_images[0][2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
|
412 |
| - "plt.imshow(true_recon_all)\n", |
| 528 | + "plt.imshow(true_recon_all, origin='lower', cmap='Greys_r')\n", |
413 | 529 | "_ = plt.colorbar()"
|
414 | 530 | ]
|
415 | 531 | },
|
|
430 | 546 | "outputs": [],
|
431 | 547 | "source": [
|
432 | 548 | "true_recon = truth_images[0][2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
|
433 |
| - "plt.imshow(true_recon)\n", |
| 549 | + "plt.imshow(true_recon, origin='lower', cmap='Greys_r')\n", |
434 | 550 | "_ = plt.colorbar()"
|
435 | 551 | ]
|
436 | 552 | },
|
|
451 | 567 | "outputs": [],
|
452 | 568 | "source": [
|
453 | 569 | "bliss_recon = bliss_images[0, 2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
|
454 |
| - "plt.imshow(bliss_recon)\n", |
| 570 | + "plt.imshow(bliss_recon, origin='lower', cmap='Greys_r')\n", |
455 | 571 | "_ = plt.colorbar()"
|
456 | 572 | ]
|
457 | 573 | },
|
|
576 | 692 | " cfg5 = compose(\"m2_config\", {\n",
|
577 | 693 | " \"train.trainer.logger=null\",\n",
|
578 | 694 | " \"train.trainer.max_epochs=0\",\n",
|
579 |
| - " \"train.pretrained_weights=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n", |
| 695 | + " \"train.pretrained_weights=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n", |
580 | 696 | " \"cached_simulator.cached_data_path=/data/scratch/regier/toy_m2\",\n",
|
581 | 697 | " \"+train.trainer.num_sanity_val_steps=0\",\n",
|
582 | 698 | " \"cached_simulator.splits=0:10/10:20/0:100\",\n",
|
|
631 | 747 | "with initialize(config_path=\"../../case_studies/dependent_tiling/\", version_base=None):\n",
|
632 | 748 | " cfg = compose(\"m2_config\", {\n",
|
633 | 749 | " \"encoder.tiles_to_crop=3\",\n",
|
634 |
| - " \"predict.weight_save_path=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n", |
| 750 | + " \"predict.weight_save_path=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n", |
635 | 751 | " # \"encoder.double_detect=false\"\n",
|
636 | 752 | " })\n",
|
637 | 753 | "\n",
|
|
670 | 786 | "outputs": [],
|
671 | 787 | "source": [
|
672 | 788 | "encoder = instantiate(cfg.encoder)\n",
|
673 |
| - "enc_state_dict = torch.load(\"/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\")\n", |
| 789 | + "enc_state_dict = torch.load(\"/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\")\n", |
674 | 790 | "enc_state_dict = enc_state_dict[\"state_dict\"]\n",
|
675 | 791 | "encoder.load_state_dict(enc_state_dict)\n",
|
676 | 792 | "encoder.eval()\n",
|
|
0 commit comments