Skip to content

Commit

Permalink
improving m2 case study
Browse files Browse the repository at this point in the history
  • Loading branch information
jeff-regier committed Jan 28, 2024
1 parent 0dd540e commit 2d883bb
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 343 deletions.
168 changes: 142 additions & 26 deletions case_studies/dependent_tiling/m2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"plt.imshow(f[0].data, origin='lower', cmap='gray_r')\n",
"plt.imshow(f[0].data, origin='lower', cmap='Greys_r')\n",
"print(\"Behold, the M2 globular cluster!\")"
]
},
Expand All @@ -54,7 +54,22 @@
"outputs": [],
"source": [
"logimage = np.log(f[0].data - f[0].data.min() + 1)\n",
"_ = plt.imshow(logimage, origin='lower', cmap='gray_r')"
"plt.imshow(logimage, origin='lower', cmap='Greys_r');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib.patches import Rectangle\n",
"\n",
"plt.imshow(logimage, origin='lower', cmap='Greys_r')\n",
"rect = Rectangle((310, 630), 100, 100, linewidth=2, edgecolor='r', facecolor='none')\n",
"_ = plt.gca().add_patch(rect)\n",
"plt.xticks([])\n",
"plt.yticks([]);"
]
},
{
Expand Down Expand Up @@ -90,12 +105,12 @@
"metadata": {},
"outputs": [],
"source": [
"from matplotlib.patches import Rectangle\n",
"original = f[0].data[630:730, 310:410]\n",
"\n",
"arcsinh_median = np.arcsinh((original - np.median(original)))\n",
"\n",
"plt.imshow(logimage, origin='lower', cmap='gray_r')\n",
"plt.scatter(plocs_all[:, 1], plocs_all[:, 0], s=10, c='r')\n",
"rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')\n",
"plt.gca().add_patch(rect)"
"clipped = original.clip(max=np.quantile(original, 0.98))\n",
"arcsinh_clipped = np.arcsinh((clipped - np.median(clipped)));"
]
},
{
Expand All @@ -104,9 +119,22 @@
"metadata": {},
"outputs": [],
"source": [
"in_bounds = (plocs_all[:, 1] > 310) & (plocs_all[:, 1] < 410)\n",
"in_bounds &= (plocs_all[:, 0] > 630) & (plocs_all[:, 0] < 730)\n",
"in_bounds.sum()"
"\n",
"\n",
"fig, axs = plt.subplots(1, 3, figsize=(10, 10))\n",
"\n",
"images = [original, arcsinh_median, arcsinh_clipped]\n",
"titles = ['original', 'arcsinc', 'arcsinc with clipping']\n",
"\n",
"for i, img in enumerate(images):\n",
" ax = axs[i]\n",
" ax.imshow(img, origin='lower', cmap='Greys_r')\n",
" ax.set_title(titles[i])\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
Expand All @@ -115,10 +143,9 @@
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(logimage, origin='lower', cmap='gray_r')\n",
"plt.scatter(plocs_all[:, 1][in_bounds], plocs_all[:, 0][in_bounds], s=10, c='r')\n",
"rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')\n",
"_ = plt.gca().add_patch(rect)"
"in_bounds = (plocs_all[:, 1] > 310) & (plocs_all[:, 1] < 410)\n",
"in_bounds &= (plocs_all[:, 0] > 630) & (plocs_all[:, 0] < 730)\n",
"in_bounds.sum()"
]
},
{
Expand Down Expand Up @@ -220,6 +247,32 @@
"true_tile_cat.n_sources.sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(1, 3, figsize=(10, 10))\n",
"\n",
"cutoffs = [20, 22.065, 24]\n",
"\n",
"for i, cutoff in enumerate(cutoffs):\n",
" is_bright = sdss_r_mag < cutoff\n",
" plocs_square_bright = plocs_square[is_bright]\n",
" ax = axs[i]\n",
" ax.imshow(arcsinh_clipped, origin='lower', cmap='Greys_r')\n",
" ax.scatter(plocs_square_bright[:, 1], plocs_square_bright[:, 0], s=5, c='r')\n",
" ax.set_title(f\"magnitude < {cutoff}\")\n",
" ax.set_xlim(0, 100)\n",
" ax.set_ylim(0, 100)\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -244,7 +297,7 @@
"with initialize(config_path=\"../../case_studies/dependent_tiling/\", version_base=None):\n",
" cfg = compose(\"m2_config\", {\n",
" \"encoder.tiles_to_crop=3\",\n",
" \"predict.weight_save_path=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n",
" \"predict.weight_save_path=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n",
" # \"encoder.double_detect=false\"\n",
" })"
]
Expand Down Expand Up @@ -290,8 +343,58 @@
"metadata": {},
"outputs": [],
"source": [
"starnet = {\n",
" \"recall\": [0.95, 0.91, 0.79, 0.7, 0.7, 0.62, 0.59, 0.4],\n",
" \"precision\": [0.96, 0.97, 0.79, 0.8, 0.68, 0.6, 0.45, 0.35]\n",
"}\n",
"\n",
"starnet[\"f1\"] = 2 * np.array(starnet[\"recall\"]) * np.array(starnet[\"precision\"])\n",
"starnet[\"f1\"] /= (np.array(starnet[\"recall\"]) + np.array(starnet[\"precision\"]))\n",
"\n",
"for name, metric in metrics.items():\n",
" metric.plot()"
" metric.plot()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check calibration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"counts = []\n",
"\n",
"for i in range(15):\n",
" bliss_cats = predict(cfg.predict)\n",
" bliss_cat_pair, = bliss_cats.values()\n",
" bliss_cat = bliss_cat_pair[\"sample_cat\"].to_full_catalog()\n",
" counts.append(bliss_cat.n_sources.sum())\n",
"\n",
"counts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cs = torch.tensor([c.item() for c in counts]).float()\n",
"cs.mean(), cs.quantile(0.05), cs.quantile(0.95)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Independent tiling (baseline)"
]
},
{
Expand All @@ -316,12 +419,25 @@
"bliss_cat_marginal = bliss_cat_pair[\"mode_cat\"].to_full_catalog()\n",
"matching = matcher.match_catalogs(true_cat, bliss_cat_marginal)\n",
"metric = metrics(true_cat, bliss_cat_marginal, matching)\n",
"for name, m in metrics.items():\n",
" m.plot()\n",
"\n",
"m = metrics[\"DetectionPerformance\"]\n",
"m.plot()\n",
"\n",
"metric[\"detection_recall\"], metric[\"detection_precision\"], metric[\"detection_f1\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"recall = m.n_true_matches / m.n_true_sources\n",
"precision = m.n_est_matches / m.n_est_sources\n",
"f1 = 2 * precision * recall / (precision + recall)\n",
"real = {\"recall\": recall, \"precision\": precision, \"f1\": f1}"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -339,7 +455,7 @@
" cfg3 = compose(\"m2_config\", {\n",
" \"train.trainer.logger=null\",\n",
" \"train.trainer.max_epochs=0\",\n",
" \"train.pretrained_weights=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n",
" \"train.pretrained_weights=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n",
" \"cached_simulator.cached_data_path=/data/scratch/regier/toy_m2\",\n",
" \"+train.trainer.num_sanity_val_steps=0\",\n",
"# \"encoder.double_detect=false\"\n",
Expand Down Expand Up @@ -388,7 +504,7 @@
"outputs": [],
"source": [
"obs_image = torch.from_numpy(dataset[0][\"image\"][2][6:-6, 6:-6])\n",
"plt.imshow(obs_image)\n",
"plt.imshow(obs_image, origin='lower', cmap='Greys_r')\n",
"_ = plt.colorbar()"
]
},
Expand All @@ -409,7 +525,7 @@
"outputs": [],
"source": [
"true_recon_all = truth_images[0][2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
"plt.imshow(true_recon_all)\n",
"plt.imshow(true_recon_all, origin='lower', cmap='Greys_r')\n",
"_ = plt.colorbar()"
]
},
Expand All @@ -430,7 +546,7 @@
"outputs": [],
"source": [
"true_recon = truth_images[0][2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
"plt.imshow(true_recon)\n",
"plt.imshow(true_recon, origin='lower', cmap='Greys_r')\n",
"_ = plt.colorbar()"
]
},
Expand All @@ -451,7 +567,7 @@
"outputs": [],
"source": [
"bliss_recon = bliss_images[0, 2] + dataset[0][\"background\"][2][6:-6, 6:-6]\n",
"plt.imshow(bliss_recon)\n",
"plt.imshow(bliss_recon, origin='lower', cmap='Greys_r')\n",
"_ = plt.colorbar()"
]
},
Expand Down Expand Up @@ -576,7 +692,7 @@
" cfg5 = compose(\"m2_config\", {\n",
" \"train.trainer.logger=null\",\n",
" \"train.trainer.max_epochs=0\",\n",
" \"train.pretrained_weights=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n",
" \"train.pretrained_weights=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n",
" \"cached_simulator.cached_data_path=/data/scratch/regier/toy_m2\",\n",
" \"+train.trainer.num_sanity_val_steps=0\",\n",
" \"cached_simulator.splits=0:10/10:20/0:100\",\n",
Expand Down Expand Up @@ -631,7 +747,7 @@
"with initialize(config_path=\"../../case_studies/dependent_tiling/\", version_base=None):\n",
" cfg = compose(\"m2_config\", {\n",
" \"encoder.tiles_to_crop=3\",\n",
" \"predict.weight_save_path=/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\",\n",
" \"predict.weight_save_path=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\",\n",
" # \"encoder.double_detect=false\"\n",
" })\n",
"\n",
Expand Down Expand Up @@ -670,7 +786,7 @@
"outputs": [],
"source": [
"encoder = instantiate(cfg.encoder)\n",
"enc_state_dict = torch.load(\"/home/regier/bliss/output/mean_sources/version_0/checkpoints/best_encoder.ckpt\")\n",
"enc_state_dict = torch.load(\"/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt\")\n",
"enc_state_dict = enc_state_dict[\"state_dict\"]\n",
"encoder.load_state_dict(enc_state_dict)\n",
"encoder.eval()\n",
Expand Down
Loading

0 comments on commit 2d883bb

Please sign in to comment.