diff --git a/bliss/api.py b/bliss/api.py index 24dcb9bdd..5b385ef20 100644 --- a/bliss/api.py +++ b/bliss/api.py @@ -1,6 +1,6 @@ import base64 from pathlib import Path -from typing import Dict, Literal, Optional, Tuple +from typing import Dict, Literal, Optional, Tuple, TypeAlias import requests import torch @@ -17,10 +17,12 @@ from bliss.surveys import sdss_download from bliss.train import train as _train -SurveyType = Literal["decals", "hst", "lsst", "sdss"] +SurveyType: TypeAlias = Literal["decals", "hst", "lsst", "sdss"] class BlissClient: + """Client for interacting with the BLISS API.""" + def __init__(self, cwd: str): self._cwd = cwd # cached_data_path (str): Path to directory where cached data will be stored. diff --git a/docs/docsrc/conf.py b/docs/docsrc/conf.py index 90f014e20..fdc6f293d 100644 --- a/docs/docsrc/conf.py +++ b/docs/docsrc/conf.py @@ -66,3 +66,35 @@ # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] + +# For nbsphinx +nbsphinx_execute = "never" + +# Below code is necessary to ensure pandoc is available for use by nbsphinx. +# See https://stackoverflow.com/questions/62398231/building-docs-fails-due-to-missing-pandoc/71585691#71585691 # noqa: E501 # pylint: disable=line-too-long +from inspect import getsourcefile # noqa: E402 # pylint: disable=wrong-import-position + +# Get path to directory containing this file, conf.py. +PATH_OF_THIS_FILE = getsourcefile(lambda: 0) # noqa: WPS522 +DOCS_DIRECTORY = os.path.dirname(os.path.abspath(PATH_OF_THIS_FILE)) # type: ignore + + +def ensure_pandoc_installed(_): + import pypandoc # pylint: disable=import-outside-toplevel + + # Download pandoc if necessary. If pandoc is already installed and on + # the PATH, the installed version will be used. Otherwise, we will + # download a copy of pandoc into docs/bin/ and add that to our PATH. + pandoc_dir = os.path.join(DOCS_DIRECTORY, "bin") + # Add dir containing pandoc binary to the PATH environment variable + if pandoc_dir not in os.environ["PATH"].split(os.pathsep): + os.environ["PATH"] += os.pathsep + pandoc_dir + pypandoc.ensure_pandoc_installed( + version="2.11.4", + targetfolder=pandoc_dir, + delete_installer=True, + ) + + +def setup(app): + app.connect("builder-inited", ensure_pandoc_installed) diff --git a/docs/docsrc/tutorials/index.rst b/docs/docsrc/tutorials/index.rst new file mode 100644 index 000000000..d5a2e5ea4 --- /dev/null +++ b/docs/docsrc/tutorials/index.rst @@ -0,0 +1,5 @@ +Tutorials +========= + +.. toctree:: + notebooks/tutorial diff --git a/docs/docsrc/tutorials/notebooks/est_cat.fits b/docs/docsrc/tutorials/notebooks/est_cat.fits new file mode 100644 index 000000000..93f6d900e --- /dev/null +++ b/docs/docsrc/tutorials/notebooks/est_cat.fits @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f9230ff450904e505d787169070409274a359b8cbd97f86d738ef7e87eefc6a +size 14400 diff --git a/docs/docsrc/tutorials/notebooks/predict.html b/docs/docsrc/tutorials/notebooks/predict.html new file mode 100644 index 000000000..8cdc7ec16 --- /dev/null +++ b/docs/docsrc/tutorials/notebooks/predict.html @@ -0,0 +1,60 @@ + + + + + Bokeh Plot + + + + + +
+ + + + + diff --git a/docs/docsrc/tutorials/notebooks/tutorial.ipynb b/docs/docsrc/tutorials/notebooks/tutorial.ipynb new file mode 100644 index 000000000..6b3c6e88c --- /dev/null +++ b/docs/docsrc/tutorials/notebooks/tutorial.ipynb @@ -0,0 +1,1455 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BLISS AstroPy Affiliate Package\n", + "\n", + "Bayesian Light Source Separator (BLISS) is a Bayesian method for deblending and cataloging light sources." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: BLISS_HOME=/home/zhteoh/730-astropy-integration\n" + ] + } + ], + "source": [ + "%env BLISS_HOME=/home/zhteoh/730-astropy-integration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "!pip install -e $BLISS_HOME" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tutorial" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, + "outputs": [], + "source": [ + "from bliss.api import BlissClient\n", + "\n", + "bliss_client = BlissClient(cwd=\"/data/scratch/zhteoh/730-tutorial\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train the model" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate synthetic image data" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data will be saved to /data/scratch/zhteoh/730-tutorial/dataset\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Simulating images in batches for file: 100%|██████████| 2/2 [00:55<00:00, 27.61s/it]\n", + "Simulating images in batches for file: 100%|██████████| 2/2 [00:57<00:00, 28.74s/it]9s/it]\n", + "Generating and writing cached dataset files: 100%|██████████| 2/2 [01:53<00:00, 56.56s/it]\n" + ] + } + ], + "source": [ + "bliss_client.generate(\n", + " n_batches=3, \n", + " batch_size=64, \n", + " max_images_per_file=128\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Pass additional custom configuration parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data will be saved to /data/scratch/zhteoh/730-tutorial/dataset_ms0.02\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Simulating images in batches for file: 100%|██████████| 2/2 [00:14<00:00, 7.01s/it]\n", + "Simulating images in batches for file: 100%|██████████| 2/2 [00:13<00:00, 6.89s/it]9s/it]\n", + "Generating and writing cached dataset files: 100%|██████████| 2/2 [00:27<00:00, 13.97s/it]\n" + ] + } + ], + "source": [ + "# Alter default cached_data_path\n", + "bliss_client.cached_data_path = \"/data/scratch/zhteoh/730-tutorial/dataset_ms0.02\"\n", + "\n", + "bliss_client.generate(\n", + " n_batches=3, # required\n", + " batch_size=64, # required\n", + " max_images_per_file=128, # required\n", + " simulator={\"prior\": {\"mean_sources\": 0.02}}, # optional\n", + " generate={\"file_prefix\": \"dataset\"}, # optional\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dataset_0.pt dataset_1.pt hparams.yaml\n", + "18M\t/data/scratch/zhteoh/730-tutorial/dataset_ms0.02\n", + "128\n", + "torch.Size([1, 80, 80])\n" + ] + } + ], + "source": [ + "# Check that the dataset is generated\n", + "!ls /data/scratch/zhteoh/730-tutorial/dataset_ms0.02\n", + "!du -sh /data/scratch/zhteoh/730-tutorial/dataset_ms0.02\n", + "# !cat /data/scratch/zhteoh/730-tutorial/dataset/hparams.yaml\n", + "\n", + "dataset_0 = bliss_client.get_dataset_file(filename=\"dataset_0.pt\")\n", + "print(len(dataset_0))\n", + "print(dataset_0[0][\"images\"].shape)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Train the model" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Without pretrained weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bliss_client.train(weight_save_path=\"tutorial_encoder/0.pt\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### With pretrained weights\n", + "\n", + "Download our relevant pretrained weights for your sky survey." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sdss_pretrained.pt sdss.pt sdss.pt.log.json\n" + ] + } + ], + "source": [ + "import os\n", + "assert os.path.exists(\"/data/scratch/zhteoh/730-tutorial/pretrained_weights\")\n", + "\n", + "bliss_client.load_pretrained_weights_for_survey(survey=\"sdss\", filename=\"sdss_pretrained.pt\")\n", + "\n", + "!ls /data/scratch/zhteoh/730-tutorial/pretrained_weights" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Train on cached generated disk dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Global seed set to 42\n", + "\n", + " from n params module arguments \n", + " 0 -1 1 3328 yolov5.models.common.Conv [2, 64, 5, 1] \n", + " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n", + " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", + " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n", + " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", + " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n", + " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n", + " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n", + " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n", + " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n", + " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n", + " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n", + " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", + " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n", + " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n", + " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n", + " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n", + " 19 [17] 1 16918 yolov5.models.yolo.Detect [17, [[4, 4]], [768]] \n", + "Model summary: 275 layers, 30770326 parameters, 30770326 gradients, 363.2 GFLOPs\n", + "\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n", + "\n", + " | Name | Type | Params\n", + "-------------------------------------------\n", + "0 | model | DetectionModel | 30.8 M\n", + "1 | metrics | BlissMetrics | 0 \n", + "-------------------------------------------\n", + "30.8 M Trainable params\n", + "0 Non-trainable params\n", + "30.8 M Total params\n", + "123.081 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c8e9464eb3b1478fb207cdf66f8ff24c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4921c5dee60a4577880cf5707261fa3f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "00403dedf14d43d0a1925c4d161c2952", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 9, global step 20: 'val/loss' reached 3.31720 (best 3.31720), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=9-val_loss=3.317.ckpt' as top 1\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b89d7716af134e32966e77cbaed9808f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 19, global step 40: 'val/loss' reached 2.36339 (best 2.36339), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=19-val_loss=2.363.ckpt' as top 1\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b1c6823cc44a48bf8d9910518baa2a1e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 29, global step 60: 'val/loss' reached 2.16617 (best 2.16617), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=29-val_loss=2.166.ckpt' as top 1\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5eb492e332654e97b8d1e2a2d0105e85", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 39, global step 80: 'val/loss' reached 2.11654 (best 2.11654), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=39-val_loss=2.117.ckpt' as top 1\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0c9ab0ecf2ee4b8e82fb00fc74f4d733", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 49, global step 100: 'val/loss' reached 2.10294 (best 2.10294), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=49-val_loss=2.103.ckpt' as top 1\n", + "`Trainer.fit` stopped: `max_epochs=50` reached.\n" + ] + } + ], + "source": [ + "bliss_client.train_on_cached_data(\n", + " weight_save_path=\"tutorial_encoder/0.pt\",\n", + " train_n_batches=2,\n", + " batch_size=64,\n", + " val_split_file_idxs=[1],\n", + " pretrained_weights_filename=\"sdss_pretrained.pt\",\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the model" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Using sample dataset" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Get predictions for the sample dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " from n params module arguments \n", + " 0 -1 1 3328 yolov5.models.common.Conv [2, 64, 5, 1] \n", + " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n", + " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", + " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n", + " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", + " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n", + " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n", + " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n", + " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n", + " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n", + " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n", + " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n", + " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", + " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n", + " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n", + " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n", + " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n", + " 19 [17] 1 16918 yolov5.models.yolo.Detect [17, [[4, 4]], [768]] \n", + "Model summary: 275 layers, 30770326 parameters, 30770326 gradients, 363.2 GFLOPs\n", + "\n" + ] + } + ], + "source": [ + "est_cat, est_cat_table, galaxy_params_table, pred_table = bliss_client.predict_sdss(\n", + " data_path=\"data/sdss\", \n", + " weight_save_path=\"tutorial_encoder/0.pt\",\n", + " # predict={\"dataset\": {\"run\": 94, \"camcol\": 1, \"fields\": [12]}}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " Bokeh Plot\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "bliss_client.plot_predictions_in_notebook()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of entries: 85\n" + ] + }, + { + "data": { + "text/html": [ + "Table length=5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
idxstar_log_fluxesstar_fluxesgalaxy_boolsstar_boolsgalaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b]
08.2156553698.399201( 5088, 0.28212, 2.886, 0.42372, 1.2986, 0.4571, 0.57758)
18.2903323985.156501( 5808.9, 0.27946, 2.7233, 0.4484, 1.2998, 0.44768, 0.55768)
28.7449156278.6801( 6868, 0.12816, 2.9083, 0.51797, 1.519, 0.28474, 0.5221)
311.45227794115.4201( 50677, 0.034118, 3.0456, 0.55559, 1.4297, 0.19035, 0.56381)
49.456284512788.28101( 2848.1, 0.20513, 2.8107, 0.48494, 1.5086, 0.30302, 0.67885)
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"Number of entries:\", len(est_cat_table))\n", + "est_cat_table.show_in_notebook(display_length=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of entries: 85\n" + ] + }, + { + "data": { + "text/html": [ + "Table length=5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
idxgalaxy_fluxgalaxy_disk_fracgalaxy_beta_radiansgalaxy_disk_qgalaxy_a_dgalaxy_bulge_qgalaxy_a_b
nmgyradarcsecarcsec
05087.96730.282119842.88603450.423721081.29857670.45709880.57758015
15808.8570.27945562.72332330.448399421.29976550.447678740.5576777
26867.96630.128157622.90834980.51796811.51904340.28474130.5220979
350676.60.034117743.04564790.55558911.42966380.190351680.5638064
42848.12520.20513362.8106940.484937041.50864290.303015950.6788529
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"Number of entries:\", len(galaxy_params_table))\n", + "galaxy_params_table.show_in_notebook(display_length=5)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Inspect probabilistic predictions\n", + "\n", + "BLISS produces probability distributions on the predicted latent variables." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of entries: 24964\n" + ] + }, + { + "data": { + "text/html": [ + "Table length=5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
idxon_prob_falseon_prob_truestar_log_flux_meanstar_log_flux_stdgalaxy_prob_falsegalaxy_prob_truegalsim_flux_meangalsim_flux_stdgalsim_disk_frac_meangalsim_disk_frac_stdgalsim_beta_radians_meangalsim_beta_radians_stdgalsim_disk_q_meangalsim_disk_q_stdgalsim_a_d_meangalsim_a_d_stdgalsim_bulge_q_meangalsim_bulge_q_stdgalsim_a_b_meangalsim_a_b_std
dex(nmgy)dex(nmgy)nmgynmgyradradarcsecarcsecarcsecarcsec
00.999104260.0008957216.8241350.329000770.013205230.986794776.77296540.402609560.18576361.8083208-0.0019018651.8978670.124551771.86597630.73674680.7742839-0.0444529061.7691844-0.1013972760.69983375
10.99971820.000281835676.3990460.39699620.0026409030.99735916.37655640.395309870.321972131.7079808-0.062718151.69104660.0484154221.73762760.80088590.77584165-0.153431891.7472392-0.187554120.7031531
20.999437030.00056298276.70786950.274870460.0090195540.990980456.71733760.404301760.166057111.747328-0.0297858721.7534314-0.0366179941.83283930.707646370.7884975-0.165706631.8380656-0.210658310.6848675
30.999401330.00059868966.84431550.349018720.0099844340.990015576.8441160.415128860.21666551.8009934-0.066389561.80541370.081091641.90429570.786496160.8000487-0.120503191.8984406-0.171592710.70409834
40.99978230.000217652726.53415780.403667120.0030478240.99695226.47544570.40730670.405217651.7126824-0.126404761.7070930.0238969331.82595310.770179750.7748211-0.157747981.8240426-0.222757820.7017962
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"Number of entries:\", len(pred_table))\n", + "pred_table.show_in_notebook(display_length=5)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Save predicted catalog to FITS file" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "est_cat_table.write(\"est_cat.fits\", format=\"fits\", overwrite=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of entries: 85\n" + ] + }, + { + "data": { + "text/html": [ + "Table length=5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
idxstar_log_fluxesstar_fluxesgalaxy_boolsstar_boolsgalaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b]
08.2156553698.399201( 5088, 0.28212, 2.886, 0.42372, 1.2986, 0.4571, 0.57758)
18.2903323985.156501( 5808.9, 0.27946, 2.7233, 0.4484, 1.2998, 0.44768, 0.55768)
28.7449156278.6801( 6868, 0.12816, 2.9083, 0.51797, 1.519, 0.28474, 0.5221)
311.45227794115.4201( 50677, 0.034118, 3.0456, 0.55559, 1.4297, 0.19035, 0.56381)
49.456284512788.28101( 2848.1, 0.20513, 2.8107, 0.48494, 1.5086, 0.30302, 0.67885)
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check that catalog is saved as intended\n", + "from astropy.table import Table\n", + "\n", + "est_cat_table = Table.read(\"est_cat.fits\", format=\"fits\")\n", + "print(\"Number of entries:\", len(est_cat_table))\n", + "est_cat_table.show_in_notebook(display_length=5)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Evaluate prediction" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'detection_precision': tensor(0.), 'detection_recall': tensor(0.), 'f1': tensor(nan), 'avg_distance': tensor(78.10789), 'n_matches': tensor(0), 'n_matches_gal_coadd': tensor(0), 'class_acc': tensor(nan), 'gal_tp': tensor(0), 'gal_fp': tensor(0), 'gal_fn': tensor(0), 'gal_tn': tensor(0)}\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "from bliss.metrics import BlissMetrics\n", + "from bliss.surveys.sdss import PhotoFullCatalog\n", + "\n", + "sdss_data_path = \"/data/scratch/zhteoh/730-tutorial/data/sdss\"\n", + "photo_cat = PhotoFullCatalog.from_file(sdss_data_path, run=94, camcol=1, field=12, band=2)\n", + "\n", + "est_cat_cuda = est_cat.to(torch.device(\"cpu\"))\n", + "photo_cat_cuda = photo_cat.to(torch.device(\"cpu\"))\n", + "\n", + "metrics = BlissMetrics()\n", + "results = metrics(est_cat_cuda, photo_cat_cuda)\n", + "\n", + "print(results)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Using user-specified dataset" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Download online dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "run: 2699 camcol: 4 field: 71\n" + ] + } + ], + "source": [ + "from astropy.coordinates import SkyCoord\n", + "from astroquery.sdss import SDSS\n", + "from pathlib import Path\n", + "\n", + "# pos = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') # 1011/3/44\n", + "# pos = SkyCoord(\"1h8m05.73s +13d10m20.3s\", frame=\"icrs\") # 4829/5/27\n", + "pos = SkyCoord(\"1h2m05.83s -2d11m20.3s\", frame=\"icrs\") # 2699/4/71\n", + "region = SDSS.query_region(pos, radius=\"5 arcsec\")\n", + "run, camcol, field = region[\"run\"][0], region[\"camcol\"][0], region[\"field\"][0]\n", + "print(\"run:\", run, \"camcol:\", camcol, \"field:\", field)\n", + "bliss_client.load_survey(\"sdss\", run, camcol, field, download_dir=Path(\"data/sdss\"))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Get predictions for the downloaded dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " from n params module arguments \n", + " 0 -1 1 3328 yolov5.models.common.Conv [2, 64, 5, 1] \n", + " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n", + " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", + " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n", + " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", + " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n", + " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n", + " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n", + " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n", + " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n", + " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n", + " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n", + " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", + " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n", + " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n", + " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n", + " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n", + " 19 [17] 1 16918 yolov5.models.yolo.Detect [17, [[4, 4]], [768]] \n", + "Model summary: 275 layers, 30770326 parameters, 30770326 gradients, 363.2 GFLOPs\n", + "\n" + ] + } + ], + "source": [ + "est_cat_dl, est_cat_table_dl, galaxy_params_table_dl, pred_table_dl = bliss_client.predict_sdss(\n", + " data_path=\"data/sdss\",\n", + " weight_save_path=\"tutorial_encoder/0.pt\",\n", + " predict={\"dataset\": {\"run\": 1011, \"camcol\": 3, \"fields\": [44]}}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'detection_precision': tensor(0.), 'detection_recall': tensor(0.), 'f1': tensor(nan), 'avg_distance': tensor(63.40228), 'n_matches': tensor(0), 'n_matches_gal_coadd': tensor(0), 'class_acc': tensor(nan), 'gal_tp': tensor(0), 'gal_fp': tensor(0), 'gal_fn': tensor(0), 'gal_tn': tensor(0)}\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "from bliss.metrics import BlissMetrics\n", + "from bliss.surveys.sdss import PhotoFullCatalog\n", + "\n", + "sdss_data_path = \"/data/scratch/zhteoh/730-tutorial/data/sdss\"\n", + "photo_cat = PhotoFullCatalog.from_file(sdss_data_path, run=94, camcol=1, field=12, band=2)\n", + "\n", + "est_cat_cuda = est_cat_dl.to(torch.device(\"cpu\"))\n", + "photo_cat_cuda = photo_cat.to(torch.device(\"cpu\"))\n", + "\n", + "metrics = BlissMetrics()\n", + "results = metrics(est_cat_cuda, photo_cat_cuda)\n", + "\n", + "print(results)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " Bokeh Plot\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "bliss_client.plot_predictions_in_notebook()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of entries: 44\n" + ] + }, + { + "data": { + "text/html": [ + "Table length=5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
idxstar_log_fluxesstar_fluxesgalaxy_boolsstar_boolsgalaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b]
dex(nmgy)nmgy
013.136136506933.9710( 2.7742e+06, 0.46441, 3.8738, 0.73526, 0.58455, 0.90716, 0.33317)
17.32205871513.316201( 1896.2, 0.37948, 3.3034, 0.4585, 1.1716, 0.46595, 0.65604)
28.2040573655.750710( 8565.5, 0.3654, 2.574, 0.54229, 1.3557, 0.53009, 0.75448)
37.53414151870.837610( 3435.6, 0.40961, 2.7876, 0.50608, 1.2632, 0.51647, 0.76225)
49.34272911415.51401( 18045, 0.11461, 3.8812, 0.51686, 1.4116, 0.3745, 0.46171)
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"Number of entries:\", len(galaxy_params_table_dl))\n", + "est_cat_table_dl.show_in_notebook(display_length=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of entries: 44\n" + ] + }, + { + "data": { + "text/html": [ + "Table length=5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
idxgalaxy_fluxgalaxy_disk_fracgalaxy_beta_radiansgalaxy_disk_qgalaxy_a_dgalaxy_bulge_qgalaxy_a_b
nmgyradarcsecarcsec
02774175.50.464405483.8738080.735257740.584545250.90715910.33316585
11896.22950.379483433.30335740.458500331.17162440.46594670.65603966
28565.4860.36540132.5739720.542292951.35572590.53008710.75447744
33435.59670.40960512.7875640.50607571.26319590.51646720.76224524
418044.5330.114606393.8812030.516858341.41160690.37449910.46170548
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"Number of entries:\", len(galaxy_params_table_dl))\n", + "galaxy_params_table_dl.show_in_notebook(display_length=5)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Inspect probabilistic predictions" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of entries: 24964\n" + ] + }, + { + "data": { + "text/html": [ + "Table length=5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
idxon_prob_falseon_prob_truestar_log_flux_meanstar_log_flux_stdgalaxy_prob_falsegalaxy_prob_truegalsim_flux_meangalsim_flux_stdgalsim_disk_frac_meangalsim_disk_frac_stdgalsim_beta_radians_meangalsim_beta_radians_stdgalsim_disk_q_meangalsim_disk_q_stdgalsim_a_d_meangalsim_a_d_stdgalsim_bulge_q_meangalsim_bulge_q_stdgalsim_a_b_meangalsim_a_b_std
dex(nmgy)dex(nmgy)nmgynmgyradradarcsecarcsecarcsecarcsec
00.999570130.00042985036.4476890.337768730.00635105370.993648956.2940580.3667410.286194561.73725820.0472996231.7643875-0.0238287451.73568650.722866060.75291-0.197221281.6656573-0.207115890.6738205
10.999485430.000514561146.51808640.329645870.0112077590.988792246.48607250.395467970.203102831.7870133-0.019657851.7278929-0.0388615131.72390020.70245480.7723759-0.207538841.8130727-0.302320720.72798145
20.999681950.000318062326.41813370.357122240.00465124850.995348756.36766150.400501040.360249041.7220504-0.131026271.6104832-0.106122971.74262330.67262580.74757767-0.19606281.8077326-0.336684230.7270691
30.99982340.000176595876.14815140.371253520.00275814530.997241856.0827190.396538260.391197441.5810724-0.166712761.5083982-0.167770391.66679550.62244560.7258638-0.289596561.7095209-0.379691360.70144725
40.999556060.000443913836.68093970.34111850.0065448880.99345516.613920.389036770.28013231.7113833-0.113075971.7199850.0106115341.84836510.730939150.7584289-0.124616861.8759558-0.262680050.71290445
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"Number of entries:\", len(pred_table_dl))\n", + "pred_table_dl.show_in_notebook(display_length=5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/poetry.lock b/poetry.lock index 96b3cbb74..886a3c263 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3199,6 +3199,26 @@ traitlets = ">=5.1" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] test = ["pep440", "pre-commit", "pytest", "testpath"] +[[package]] +name = "nbsphinx" +version = "0.9.2" +description = "Jupyter Notebook Tools for Sphinx" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "nbsphinx-0.9.2-py3-none-any.whl", hash = "sha256:2746680ece5ad3b0e980639d717a5041a1c1aafb416846b72dfaeecc306bc351"}, + {file = "nbsphinx-0.9.2.tar.gz", hash = "sha256:540db7f4066347f23d0650c4ae8e7d85334c69adf749e030af64c12e996ff88e"}, +] + +[package.dependencies] +docutils = "*" +jinja2 = "*" +nbconvert = "!=5.4" +nbformat = "*" +sphinx = ">=1.8" +traitlets = ">=5" + [[package]] name = "nbstripout" version = "0.6.1" @@ -4113,6 +4133,18 @@ tomlkit = ">=0.10.1" spelling = ["pyenchant (>=3.2,<4.0)"] testutils = ["gitpython (>3)"] +[[package]] +name = "pypandoc" +version = "1.11" +description = "Thin wrapper for pandoc." +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pypandoc-1.11-py3-none-any.whl", hash = "sha256:b260596934e9cfc6513056110a7c8600171d414f90558bf4407e68b209be8007"}, + {file = "pypandoc-1.11.tar.gz", hash = "sha256:7f6d68db0e57e0f6961bec2190897118c4d305fc2d31c22cd16037f22ee084a5"}, +] + [[package]] name = "pyparsing" version = "3.0.9" @@ -6002,4 +6034,4 @@ tqdm = ">=4.64.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "454a2b76a5f5ca6dfb7076c114e5fed2b403dbf7f18b256f4c44a84aa6586b71" +content-hash = "9b3b8708de03e5bb5ab3f510a63b9fe48a1f62991a8d7aaa0549e2082683d73a" diff --git a/pyproject.toml b/pyproject.toml index 4add4b986..135cba54e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,8 @@ sphinx-rtd-theme = ">=0.5.2" torch-tb-profiler = "^0.4.1" tqdm = ">=4.62.3" wemake-python-styleguide = ">=0.16.1" +nbsphinx = "^0.9.2" +pypandoc = "^1.11" [build-system] build-backend = "poetry.core.masonry.api"