Skip to content

Commit 2d883bb

Browse files
committed
improving m2 case study
1 parent 0dd540e commit 2d883bb

File tree

3 files changed

+156
-343
lines changed

3 files changed

+156
-343
lines changed

case_studies/dependent_tiling/m2.ipynb

Lines changed: 142 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"source": [
4444
"from matplotlib import pyplot as plt\n",
4545
"\n",
46-
"plt.imshow(f[0].data, origin='lower', cmap='gray_r')\n",
46+
"plt.imshow(f[0].data, origin='lower', cmap='Greys_r')\n",
4747
"print(\"Behold, the M2 globular cluster!\")"
4848
]
4949
},
@@ -54,7 +54,22 @@
5454
"outputs": [],
5555
"source": [
5656
"logimage = np.log(f[0].data - f[0].data.min() + 1)\n",
57-
"_ = plt.imshow(logimage, origin='lower', cmap='gray_r')"
57+
"plt.imshow(logimage, origin='lower', cmap='Greys_r');"
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": null,
63+
"metadata": {},
64+
"outputs": [],
65+
"source": [
66+
"from matplotlib.patches import Rectangle\n",
67+
"\n",
68+
"plt.imshow(logimage, origin='lower', cmap='Greys_r')\n",
69+
"rect = Rectangle((310, 630), 100, 100, linewidth=2, edgecolor='r', facecolor='none')\n",
70+
"_ = plt.gca().add_patch(rect)\n",
71+
"plt.xticks([])\n",
72+
"plt.yticks([]);"
5873
]
5974
},
6075
{
@@ -90,12 +105,12 @@
90105
"metadata": {},
91106
"outputs": [],
92107
"source": [
93-
"from matplotlib.patches import Rectangle\n",
108+
"original = f[0].data[630:730, 310:410]\n",
109+
"\n",
110+
"arcsinh_median = np.arcsinh((original - np.median(original)))\n",
94111
"\n",
95-
"plt.imshow(logimage, origin='lower', cmap='gray_r')\n",
96-
"plt.scatter(plocs_all[:, 1], plocs_all[:, 0], s=10, c='r')\n",
97-
"rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')\n",
98-
"plt.gca().add_patch(rect)"
112+
"clipped = original.clip(max=np.quantile(original, 0.98))\n",
113+
"arcsinh_clipped = np.arcsinh((clipped - np.median(clipped)));"
99114
]
100115
},
101116
{
@@ -104,9 +119,22 @@
104119
"metadata": {},
105120
"outputs": [],
106121
"source": [
107-
"in_bounds = (plocs_all[:, 1] > 310) & (plocs_all[:, 1] < 410)\n",
108-
"in_bounds &= (plocs_all[:, 0] > 630) & (plocs_all[:, 0] < 730)\n",
109-
"in_bounds.sum()"
122+
"\n",
123+
"\n",
124+
"fig, axs = plt.subplots(1, 3, figsize=(10, 10))\n",
125+
"\n",
126+
"images = [original, arcsinh_median, arcsinh_clipped]\n",
127+
"titles = ['original', 'arcsinh', 'arcsinh with clipping']\n",
128+
"\n",
129+
"for i, img in enumerate(images):\n",
130+
" ax = axs[i]\n",
131+
" ax.imshow(img, origin='lower', cmap='Greys_r')\n",
132+
" ax.set_title(titles[i])\n",
133+
" ax.set_xticks([])\n",
134+
" ax.set_yticks([])\n",
135+
"\n",
136+
"plt.tight_layout()\n",
137+
"plt.show()"
110138
]
111139
},
112140
{
@@ -115,10 +143,9 @@
115143
"metadata": {},
116144
"outputs": [],
117145
"source": [
118-
"plt.imshow(logimage, origin='lower', cmap='gray_r')\n",
119-
"plt.scatter(plocs_all[:, 1][in_bounds], plocs_all[:, 0][in_bounds], s=10, c='r')\n",
120-
"rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')\n",
121-
"_ = plt.gca().add_patch(rect)"
146+
"in_bounds = (plocs_all[:, 1] > 310) & (plocs_all[:, 1] < 410)\n",
147+
"in_bounds &= (plocs_all[:, 0] > 630) & (plocs_all[:, 0] < 730)\n",
148+
"in_bounds.sum()"
122149
]
123150
},
124151
{
@@ -220,6 +247,32 @@
220247
"true_tile_cat.n_sources.sum()"
221248
]
222249
},
250+
{
251+
"cell_type": "code",
252+
"execution_count": null,
253+
"metadata": {},
254+
"outputs": [],
255+
"source": [
256+
"fig, axs = plt.subplots(1, 3, figsize=(10, 10))\n",
257+
"\n",
258+
"cutoffs = [20, 22.065, 24]\n",
259+
"\n",
260+
"for i, cutoff in enumerate(cutoffs):\n",
261+
" is_bright = sdss_r_mag < cutoff\n",
262+
" plocs_square_bright = plocs_square[is_bright]\n",
263+
" ax = axs[i]\n",
264+
" ax.imshow(arcsinh_clipped, origin='lower', cmap='Greys_r')\n",
265+
" ax.scatter(plocs_square_bright[:, 1], plocs_square_bright[:, 0], s=5, c='r')\n",
266+
" ax.set_title(f\"magnitude < {cutoff}\")\n",
267+
" ax.set_xlim(0, 100)\n",
268+
" ax.set_ylim(0, 100)\n",
269+
" ax.set_xticks([])\n",
270+
" ax.set_yticks([])\n",
271+
"\n",
272+
"plt.tight_layout()\n",
273+
"plt.show()\n"
274+
]
275+
},
223276
{
224277
"cell_type": "markdown",
225278
"metadata": {},
@@ -244,7 +297,7 @@
244297
"with initialize(config_path=\"../../case_studies/dependent_tiling/\", version_base=None):\n",
245298
" cfg = compose(\"m2_config\", {\n",
246299
" \"encoder.tiles_to_crop=3\",\n",
247-
" \"predict.weight_save_path=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n",
300+
" \"predict.weight_save_path=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n",
248301
" # \"encoder.double_detect=false\"\n",
249302
" })"
250303
]
@@ -290,8 +343,58 @@
290343
"metadata": {},
291344
"outputs": [],
292345
"source": [
346+
"starnet = {\n",
347+
" \"recall\": [0.95, 0.91, 0.79, 0.7, 0.7, 0.62, 0.59, 0.4],\n",
348+
" \"precision\": [0.96, 0.97, 0.79, 0.8, 0.68, 0.6, 0.45, 0.35]\n",
349+
"}\n",
350+
"\n",
351+
"starnet[\"f1\"] = 2 * np.array(starnet[\"recall\"]) * np.array(starnet[\"precision\"])\n",
352+
"starnet[\"f1\"] /= (np.array(starnet[\"recall\"]) + np.array(starnet[\"precision\"]))\n",
353+
"\n",
293354
"for name, metric in metrics.items():\n",
294-
" metric.plot()"
355+
" metric.plot()\n"
356+
]
357+
},
358+
{
359+
"cell_type": "markdown",
360+
"metadata": {},
361+
"source": [
362+
"Check calibration:"
363+
]
364+
},
365+
{
366+
"cell_type": "code",
367+
"execution_count": null,
368+
"metadata": {},
369+
"outputs": [],
370+
"source": [
371+
"%%capture\n",
372+
"counts = []\n",
373+
"\n",
374+
"for i in range(15):\n",
375+
" bliss_cats = predict(cfg.predict)\n",
376+
" bliss_cat_pair, = bliss_cats.values()\n",
377+
" bliss_cat = bliss_cat_pair[\"sample_cat\"].to_full_catalog()\n",
378+
" counts.append(bliss_cat.n_sources.sum())\n",
379+
"\n",
380+
"counts"
381+
]
382+
},
383+
{
384+
"cell_type": "code",
385+
"execution_count": null,
386+
"metadata": {},
387+
"outputs": [],
388+
"source": [
389+
"cs = torch.tensor([c.item() for c in counts]).float()\n",
390+
"cs.mean(), cs.quantile(0.05), cs.quantile(0.95)"
391+
]
392+
},
393+
{
394+
"cell_type": "markdown",
395+
"metadata": {},
396+
"source": [
397+
"### Independent tiling (baseline)"
295398
]
296399
},
297400
{
@@ -316,12 +419,25 @@
316419
"bliss_cat_marginal = bliss_cat_pair[\"mode_cat\"].to_full_catalog()\n",
317420
"matching = matcher.match_catalogs(true_cat, bliss_cat_marginal)\n",
318421
"metric = metrics(true_cat, bliss_cat_marginal, matching)\n",
319-
"for name, m in metrics.items():\n",
320-
" m.plot()\n",
422+
"\n",
423+
"m = metrics[\"DetectionPerformance\"]\n",
424+
"m.plot()\n",
321425
"\n",
322426
"metric[\"detection_recall\"], metric[\"detection_precision\"], metric[\"detection_f1\"]"
323427
]
324428
},
429+
{
430+
"cell_type": "code",
431+
"execution_count": null,
432+
"metadata": {},
433+
"outputs": [],
434+
"source": [
435+
"recall = m.n_true_matches / m.n_true_sources\n",
436+
"precision = m.n_est_matches / m.n_est_sources\n",
437+
"f1 = 2 * precision * recall / (precision + recall)\n",
438+
"real = {\"recall\": recall, \"precision\": precision, \"f1\": f1}"
439+
]
440+
},
325441
{
326442
"cell_type": "markdown",
327443
"metadata": {},
@@ -339,7 +455,7 @@
339455
" cfg3 = compose(\"m2_config\", {\n",
340456
" \"train.trainer.logger=null\",\n",
341457
" \"train.trainer.max_epochs=0\",\n",
342-
" \"train.pretrained_weights=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n",
458+
" \"train.pretrained_weights=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n",
343459
" \"cached_simulator.cached_data_path=/data/scratch/regier/toy_m2\",\n",
344460
" \"+train.trainer.num_sanity_val_steps=0\",\n",
345461
"# \"encoder.double_detect=false\"\n",
@@ -388,7 +504,7 @@
388504
"outputs": [],
389505
"source": [
390506
"obs_image = torch.from_numpy(dataset[0][\"image\"][2][6:-6, 6:-6])\n",
391-
"plt.imshow(obs_image)\n",
507+
"plt.imshow(obs_image, origin='lower', cmap='Greys_r')\n",
392508
"_ = plt.colorbar()"
393509
]
394510
},
@@ -409,7 +525,7 @@
409525
"outputs": [],
410526
"source": [
411527
"true_recon_all = truth_images[0][2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
412-
"plt.imshow(true_recon_all)\n",
528+
"plt.imshow(true_recon_all, origin='lower', cmap='Greys_r')\n",
413529
"_ = plt.colorbar()"
414530
]
415531
},
@@ -430,7 +546,7 @@
430546
"outputs": [],
431547
"source": [
432548
"true_recon = truth_images[0][2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
433-
"plt.imshow(true_recon)\n",
549+
"plt.imshow(true_recon, origin='lower', cmap='Greys_r')\n",
434550
"_ = plt.colorbar()"
435551
]
436552
},
@@ -451,7 +567,7 @@
451567
"outputs": [],
452568
"source": [
453569
"bliss_recon = bliss_images[0, 2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
454-
"plt.imshow(bliss_recon)\n",
570+
"plt.imshow(bliss_recon, origin='lower', cmap='Greys_r')\n",
455571
"_ = plt.colorbar()"
456572
]
457573
},
@@ -576,7 +692,7 @@
576692
" cfg5 = compose(\"m2_config\", {\n",
577693
" \"train.trainer.logger=null\",\n",
578694
" \"train.trainer.max_epochs=0\",\n",
579-
" \"train.pretrained_weights=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n",
695+
" \"train.pretrained_weights=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n",
580696
" \"cached_simulator.cached_data_path=/data/scratch/regier/toy_m2\",\n",
581697
" \"+train.trainer.num_sanity_val_steps=0\",\n",
582698
" \"cached_simulator.splits=0:10/10:20/0:100\",\n",
@@ -631,7 +747,7 @@
631747
"with initialize(config_path=\"../../case_studies/dependent_tiling/\", version_base=None):\n",
632748
" cfg = compose(\"m2_config\", {\n",
633749
" \"encoder.tiles_to_crop=3\",\n",
634-
" \"predict.weight_save_path=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n",
750+
" \"predict.weight_save_path=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n",
635751
" # \"encoder.double_detect=false\"\n",
636752
" })\n",
637753
"\n",
@@ -670,7 +786,7 @@
670786
"outputs": [],
671787
"source": [
672788
"encoder = instantiate(cfg.encoder)\n",
673-
"enc_state_dict = torch.load(\"/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\")\n",
789+
"enc_state_dict = torch.load(\"/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\")\n",
674790
"enc_state_dict = enc_state_dict[\"state_dict\"]\n",
675791
"encoder.load_state_dict(enc_state_dict)\n",
676792
"encoder.eval()\n",

0 commit comments

Comments
 (0)