diff --git a/bliss/api.py b/bliss/api.py
index 24dcb9bdd..5b385ef20 100644
--- a/bliss/api.py
+++ b/bliss/api.py
@@ -1,6 +1,6 @@
import base64
from pathlib import Path
-from typing import Dict, Literal, Optional, Tuple
+from typing import Dict, Literal, Optional, Tuple, TypeAlias
import requests
import torch
@@ -17,10 +17,12 @@
from bliss.surveys import sdss_download
from bliss.train import train as _train
-SurveyType = Literal["decals", "hst", "lsst", "sdss"]
+SurveyType: TypeAlias = Literal["decals", "hst", "lsst", "sdss"]
class BlissClient:
+ """Client for interacting with the BLISS API."""
+
def __init__(self, cwd: str):
self._cwd = cwd
# cached_data_path (str): Path to directory where cached data will be stored.
diff --git a/docs/docsrc/conf.py b/docs/docsrc/conf.py
index 90f014e20..fdc6f293d 100644
--- a/docs/docsrc/conf.py
+++ b/docs/docsrc/conf.py
@@ -66,3 +66,35 @@
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
+
+# For nbsphinx
+nbsphinx_execute = "never"
+
+# Below code is necessary to ensure pandoc is available for use by nbsphinx.
+# See https://stackoverflow.com/questions/62398231/building-docs-fails-due-to-missing-pandoc/71585691#71585691 # noqa: E501 # pylint: disable=line-too-long
+from inspect import getsourcefile # noqa: E402 # pylint: disable=wrong-import-position
+
+# Get path to directory containing this file, conf.py.
+PATH_OF_THIS_FILE = getsourcefile(lambda: 0) # noqa: WPS522
+DOCS_DIRECTORY = os.path.dirname(os.path.abspath(PATH_OF_THIS_FILE)) # type: ignore
+
+
+def ensure_pandoc_installed(_):
+ import pypandoc # pylint: disable=import-outside-toplevel
+
+ # Download pandoc if necessary. If pandoc is already installed and on
+ # the PATH, the installed version will be used. Otherwise, we will
+ # download a copy of pandoc into docs/bin/ and add that to our PATH.
+ pandoc_dir = os.path.join(DOCS_DIRECTORY, "bin")
+ # Add dir containing pandoc binary to the PATH environment variable
+ if pandoc_dir not in os.environ["PATH"].split(os.pathsep):
+ os.environ["PATH"] += os.pathsep + pandoc_dir
+ pypandoc.ensure_pandoc_installed(
+ version="2.11.4",
+ targetfolder=pandoc_dir,
+ delete_installer=True,
+ )
+
+
+def setup(app):
+ app.connect("builder-inited", ensure_pandoc_installed)
diff --git a/docs/docsrc/tutorials/index.rst b/docs/docsrc/tutorials/index.rst
new file mode 100644
index 000000000..d5a2e5ea4
--- /dev/null
+++ b/docs/docsrc/tutorials/index.rst
@@ -0,0 +1,5 @@
+Tutorials
+=========
+
+.. toctree::
+ notebooks/tutorial
diff --git a/docs/docsrc/tutorials/notebooks/est_cat.fits b/docs/docsrc/tutorials/notebooks/est_cat.fits
new file mode 100644
index 000000000..93f6d900e
--- /dev/null
+++ b/docs/docsrc/tutorials/notebooks/est_cat.fits
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f9230ff450904e505d787169070409274a359b8cbd97f86d738ef7e87eefc6a
+size 14400
diff --git a/docs/docsrc/tutorials/notebooks/predict.html b/docs/docsrc/tutorials/notebooks/predict.html
new file mode 100644
index 000000000..8cdc7ec16
--- /dev/null
+++ b/docs/docsrc/tutorials/notebooks/predict.html
@@ -0,0 +1,60 @@
+
+
+
+
+ Bokeh Plot
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/docsrc/tutorials/notebooks/tutorial.ipynb b/docs/docsrc/tutorials/notebooks/tutorial.ipynb
new file mode 100644
index 000000000..6b3c6e88c
--- /dev/null
+++ b/docs/docsrc/tutorials/notebooks/tutorial.ipynb
@@ -0,0 +1,1455 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# BLISS AstroPy Affiliate Package\n",
+ "\n",
+ "Bayesian Light Source Separator (BLISS) is a Bayesian method for deblending and cataloging light sources."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Installation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "env: BLISS_HOME=/home/zhteoh/730-astropy-integration\n"
+ ]
+ }
+ ],
+ "source": [
+ "%env BLISS_HOME=/home/zhteoh/730-astropy-integration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -e $BLISS_HOME"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Tutorial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from bliss.api import BlissClient\n",
+ "\n",
+ "bliss_client = BlissClient(cwd=\"/data/scratch/zhteoh/730-tutorial\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Train the model"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Generate synthetic image data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Data will be saved to /data/scratch/zhteoh/730-tutorial/dataset\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Simulating images in batches for file: 100%|██████████| 2/2 [00:55<00:00, 27.61s/it]\n",
+ "Simulating images in batches for file: 100%|██████████| 2/2 [00:57<00:00, 28.74s/it]9s/it]\n",
+ "Generating and writing cached dataset files: 100%|██████████| 2/2 [01:53<00:00, 56.56s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "bliss_client.generate(\n",
+ " n_batches=3, \n",
+ " batch_size=64, \n",
+ " max_images_per_file=128\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Pass additional custom configuration parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Data will be saved to /data/scratch/zhteoh/730-tutorial/dataset_ms0.02\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Simulating images in batches for file: 100%|██████████| 2/2 [00:14<00:00, 7.01s/it]\n",
+ "Simulating images in batches for file: 100%|██████████| 2/2 [00:13<00:00, 6.89s/it]9s/it]\n",
+ "Generating and writing cached dataset files: 100%|██████████| 2/2 [00:27<00:00, 13.97s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Alter default cached_data_path\n",
+ "bliss_client.cached_data_path = \"/data/scratch/zhteoh/730-tutorial/dataset_ms0.02\"\n",
+ "\n",
+ "bliss_client.generate(\n",
+ " n_batches=3, # required\n",
+ " batch_size=64, # required\n",
+ " max_images_per_file=128, # required\n",
+ " simulator={\"prior\": {\"mean_sources\": 0.02}}, # optional\n",
+ " generate={\"file_prefix\": \"dataset\"}, # optional\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "dataset_0.pt dataset_1.pt hparams.yaml\n",
+ "18M\t/data/scratch/zhteoh/730-tutorial/dataset_ms0.02\n",
+ "128\n",
+ "torch.Size([1, 80, 80])\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Check that the dataset is generated\n",
+ "!ls /data/scratch/zhteoh/730-tutorial/dataset_ms0.02\n",
+ "!du -sh /data/scratch/zhteoh/730-tutorial/dataset_ms0.02\n",
+ "# !cat /data/scratch/zhteoh/730-tutorial/dataset/hparams.yaml\n",
+ "\n",
+ "dataset_0 = bliss_client.get_dataset_file(filename=\"dataset_0.pt\")\n",
+ "print(len(dataset_0))\n",
+ "print(dataset_0[0][\"images\"].shape)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Train the model"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Without pretrained weights"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bliss_client.train(weight_save_path=\"tutorial_encoder/0.pt\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### With pretrained weights\n",
+ "\n",
+ "Download our relevant pretrained weights for your sky survey."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "sdss_pretrained.pt sdss.pt sdss.pt.log.json\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "assert os.path.exists(\"/data/scratch/zhteoh/730-tutorial/pretrained_weights\")\n",
+ "\n",
+ "bliss_client.load_pretrained_weights_for_survey(survey=\"sdss\", filename=\"sdss_pretrained.pt\")\n",
+ "\n",
+ "!ls /data/scratch/zhteoh/730-tutorial/pretrained_weights"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Train on cached generated disk dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Global seed set to 42\n",
+ "\n",
+ " from n params module arguments \n",
+ " 0 -1 1 3328 yolov5.models.common.Conv [2, 64, 5, 1] \n",
+ " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n",
+ " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n",
+ " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n",
+ " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n",
+ " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n",
+ " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n",
+ " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n",
+ " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n",
+ " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n",
+ " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n",
+ " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n",
+ " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
+ " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n",
+ " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n",
+ " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n",
+ " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
+ " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n",
+ " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n",
+ " 19 [17] 1 16918 yolov5.models.yolo.Detect [17, [[4, 4]], [768]] \n",
+ "Model summary: 275 layers, 30770326 parameters, 30770326 gradients, 363.2 GFLOPs\n",
+ "\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "-------------------------------------------\n",
+ "0 | model | DetectionModel | 30.8 M\n",
+ "1 | metrics | BlissMetrics | 0 \n",
+ "-------------------------------------------\n",
+ "30.8 M Trainable params\n",
+ "0 Non-trainable params\n",
+ "30.8 M Total params\n",
+ "123.081 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c8e9464eb3b1478fb207cdf66f8ff24c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4921c5dee60a4577880cf5707261fa3f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "00403dedf14d43d0a1925c4d161c2952",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 9, global step 20: 'val/loss' reached 3.31720 (best 3.31720), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=9-val_loss=3.317.ckpt' as top 1\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b89d7716af134e32966e77cbaed9808f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 19, global step 40: 'val/loss' reached 2.36339 (best 2.36339), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=19-val_loss=2.363.ckpt' as top 1\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b1c6823cc44a48bf8d9910518baa2a1e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 29, global step 60: 'val/loss' reached 2.16617 (best 2.16617), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=29-val_loss=2.166.ckpt' as top 1\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5eb492e332654e97b8d1e2a2d0105e85",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 39, global step 80: 'val/loss' reached 2.11654 (best 2.11654), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=39-val_loss=2.117.ckpt' as top 1\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0c9ab0ecf2ee4b8e82fb00fc74f4d733",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 49, global step 100: 'val/loss' reached 2.10294 (best 2.10294), saving model to '/home/zhteoh/730-astropy-integration/output/version_20/checkpoints/epoch=49-val_loss=2.103.ckpt' as top 1\n",
+ "`Trainer.fit` stopped: `max_epochs=50` reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "bliss_client.train_on_cached_data(\n",
+ " weight_save_path=\"tutorial_encoder/0.pt\",\n",
+ " train_n_batches=2,\n",
+ " batch_size=64,\n",
+ " val_split_file_idxs=[1],\n",
+ " pretrained_weights_filename=\"sdss_pretrained.pt\",\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Run the model"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Using sample dataset"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Get predictions for the sample dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " from n params module arguments \n",
+ " 0 -1 1 3328 yolov5.models.common.Conv [2, 64, 5, 1] \n",
+ " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n",
+ " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n",
+ " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n",
+ " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n",
+ " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n",
+ " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n",
+ " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n",
+ " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n",
+ " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n",
+ " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n",
+ " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n",
+ " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
+ " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n",
+ " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n",
+ " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n",
+ " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
+ " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n",
+ " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n",
+ " 19 [17] 1 16918 yolov5.models.yolo.Detect [17, [[4, 4]], [768]] \n",
+ "Model summary: 275 layers, 30770326 parameters, 30770326 gradients, 363.2 GFLOPs\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "est_cat, est_cat_table, galaxy_params_table, pred_table = bliss_client.predict_sdss(\n",
+ " data_path=\"data/sdss\", \n",
+ " weight_save_path=\"tutorial_encoder/0.pt\",\n",
+ " # predict={\"dataset\": {\"run\": 94, \"camcol\": 1, \"fields\": [12]}}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ " Bokeh Plot\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "bliss_client.plot_predictions_in_notebook()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entries: 85\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Table length=5\n",
+ "\n",
+ "idx | star_log_fluxes | star_fluxes | galaxy_bools | star_bools | galaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b] |
\n",
+ "0 | 8.215655 | 3698.3992 | 0 | 1 | ( 5088, 0.28212, 2.886, 0.42372, 1.2986, 0.4571, 0.57758) |
\n",
+ "1 | 8.290332 | 3985.1565 | 0 | 1 | ( 5808.9, 0.27946, 2.7233, 0.4484, 1.2998, 0.44768, 0.55768) |
\n",
+ "2 | 8.744915 | 6278.68 | 0 | 1 | ( 6868, 0.12816, 2.9083, 0.51797, 1.519, 0.28474, 0.5221) |
\n",
+ "3 | 11.452277 | 94115.42 | 0 | 1 | ( 50677, 0.034118, 3.0456, 0.55559, 1.4297, 0.19035, 0.56381) |
\n",
+ "4 | 9.4562845 | 12788.281 | 0 | 1 | ( 2848.1, 0.20513, 2.8107, 0.48494, 1.5086, 0.30302, 0.67885) |
\n",
+ "
\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(\"Number of entries:\", len(est_cat_table))\n",
+ "est_cat_table.show_in_notebook(display_length=5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entries: 85\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Table length=5\n",
+ "\n",
+ "idx | galaxy_flux | galaxy_disk_frac | galaxy_beta_radians | galaxy_disk_q | galaxy_a_d | galaxy_bulge_q | galaxy_a_b |
\n",
+ " | nmgy | | rad | | arcsec | | arcsec |
\n",
+ "0 | 5087.9673 | 0.28211984 | 2.8860345 | 0.42372108 | 1.2985767 | 0.4570988 | 0.57758015 |
\n",
+ "1 | 5808.857 | 0.2794556 | 2.7233233 | 0.44839942 | 1.2997655 | 0.44767874 | 0.5576777 |
\n",
+ "2 | 6867.9663 | 0.12815762 | 2.9083498 | 0.5179681 | 1.5190434 | 0.2847413 | 0.5220979 |
\n",
+ "3 | 50676.6 | 0.03411774 | 3.0456479 | 0.5555891 | 1.4296638 | 0.19035168 | 0.5638064 |
\n",
+ "4 | 2848.1252 | 0.2051336 | 2.810694 | 0.48493704 | 1.5086429 | 0.30301595 | 0.6788529 |
\n",
+ "
\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(\"Number of entries:\", len(galaxy_params_table))\n",
+ "galaxy_params_table.show_in_notebook(display_length=5)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Inspect probabilistic predictions\n",
+ "\n",
+ "BLISS produces probability distributions on the predicted latent variables."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entries: 24964\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Table length=5\n",
+ "\n",
+ "idx | on_prob_false | on_prob_true | star_log_flux_mean | star_log_flux_std | galaxy_prob_false | galaxy_prob_true | galsim_flux_mean | galsim_flux_std | galsim_disk_frac_mean | galsim_disk_frac_std | galsim_beta_radians_mean | galsim_beta_radians_std | galsim_disk_q_mean | galsim_disk_q_std | galsim_a_d_mean | galsim_a_d_std | galsim_bulge_q_mean | galsim_bulge_q_std | galsim_a_b_mean | galsim_a_b_std |
\n",
+ " | | | dex(nmgy) | dex(nmgy) | | | nmgy | nmgy | | | rad | rad | | | arcsec | arcsec | | | arcsec | arcsec |
\n",
+ "0 | 0.99910426 | 0.000895721 | 6.824135 | 0.32900077 | 0.01320523 | 0.98679477 | 6.7729654 | 0.40260956 | 0.1857636 | 1.8083208 | -0.001901865 | 1.897867 | 0.12455177 | 1.8659763 | 0.7367468 | 0.7742839 | -0.044452906 | 1.7691844 | -0.101397276 | 0.69983375 |
\n",
+ "1 | 0.9997182 | 0.00028183567 | 6.399046 | 0.3969962 | 0.002640903 | 0.9973591 | 6.3765564 | 0.39530987 | 0.32197213 | 1.7079808 | -0.06271815 | 1.6910466 | 0.048415422 | 1.7376276 | 0.8008859 | 0.77584165 | -0.15343189 | 1.7472392 | -0.18755412 | 0.7031531 |
\n",
+ "2 | 0.99943703 | 0.0005629827 | 6.7078695 | 0.27487046 | 0.009019554 | 0.99098045 | 6.7173376 | 0.40430176 | 0.16605711 | 1.747328 | -0.029785872 | 1.7534314 | -0.036617994 | 1.8328393 | 0.70764637 | 0.7884975 | -0.16570663 | 1.8380656 | -0.21065831 | 0.6848675 |
\n",
+ "3 | 0.99940133 | 0.0005986896 | 6.8443155 | 0.34901872 | 0.009984434 | 0.99001557 | 6.844116 | 0.41512886 | 0.2166655 | 1.8009934 | -0.06638956 | 1.8054137 | 0.08109164 | 1.9042957 | 0.78649616 | 0.8000487 | -0.12050319 | 1.8984406 | -0.17159271 | 0.70409834 |
\n",
+ "4 | 0.9997823 | 0.00021765272 | 6.5341578 | 0.40366712 | 0.003047824 | 0.9969522 | 6.4754457 | 0.4073067 | 0.40521765 | 1.7126824 | -0.12640476 | 1.707093 | 0.023896933 | 1.8259531 | 0.77017975 | 0.7748211 | -0.15774798 | 1.8240426 | -0.22275782 | 0.7017962 |
\n",
+ "
\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(\"Number of entries:\", len(pred_table))\n",
+ "pred_table.show_in_notebook(display_length=5)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Save predicted catalog to FITS file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "est_cat_table.write(\"est_cat.fits\", format=\"fits\", overwrite=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entries: 85\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Table length=5\n",
+ "\n",
+ "idx | star_log_fluxes | star_fluxes | galaxy_bools | star_bools | galaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b] |
\n",
+ "0 | 8.215655 | 3698.3992 | 0 | 1 | ( 5088, 0.28212, 2.886, 0.42372, 1.2986, 0.4571, 0.57758) |
\n",
+ "1 | 8.290332 | 3985.1565 | 0 | 1 | ( 5808.9, 0.27946, 2.7233, 0.4484, 1.2998, 0.44768, 0.55768) |
\n",
+ "2 | 8.744915 | 6278.68 | 0 | 1 | ( 6868, 0.12816, 2.9083, 0.51797, 1.519, 0.28474, 0.5221) |
\n",
+ "3 | 11.452277 | 94115.42 | 0 | 1 | ( 50677, 0.034118, 3.0456, 0.55559, 1.4297, 0.19035, 0.56381) |
\n",
+ "4 | 9.4562845 | 12788.281 | 0 | 1 | ( 2848.1, 0.20513, 2.8107, 0.48494, 1.5086, 0.30302, 0.67885) |
\n",
+ "
\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Check that catalog is saved as intended\n",
+ "from astropy.table import Table\n",
+ "\n",
+ "est_cat_table = Table.read(\"est_cat.fits\", format=\"fits\")\n",
+ "print(\"Number of entries:\", len(est_cat_table))\n",
+ "est_cat_table.show_in_notebook(display_length=5)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Evaluate prediction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'detection_precision': tensor(0.), 'detection_recall': tensor(0.), 'f1': tensor(nan), 'avg_distance': tensor(78.10789), 'n_matches': tensor(0), 'n_matches_gal_coadd': tensor(0), 'class_acc': tensor(nan), 'gal_tp': tensor(0), 'gal_fp': tensor(0), 'gal_fn': tensor(0), 'gal_tn': tensor(0)}\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "from bliss.metrics import BlissMetrics\n",
+ "from bliss.surveys.sdss import PhotoFullCatalog\n",
+ "\n",
+ "sdss_data_path = \"/data/scratch/zhteoh/730-tutorial/data/sdss\"\n",
+ "photo_cat = PhotoFullCatalog.from_file(sdss_data_path, run=94, camcol=1, field=12, band=2)\n",
+ "\n",
+ "est_cat_cuda = est_cat.to(torch.device(\"cpu\"))\n",
+ "photo_cat_cuda = photo_cat.to(torch.device(\"cpu\"))\n",
+ "\n",
+ "metrics = BlissMetrics()\n",
+ "results = metrics(est_cat_cuda, photo_cat_cuda)\n",
+ "\n",
+ "print(results)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Using user-specified dataset"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Download online dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "run: 2699 camcol: 4 field: 71\n"
+ ]
+ }
+ ],
+ "source": [
+ "from astropy.coordinates import SkyCoord\n",
+ "from astroquery.sdss import SDSS\n",
+ "from pathlib import Path\n",
+ "\n",
+ "# pos = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') # 1011/3/44\n",
+ "# pos = SkyCoord(\"1h8m05.73s +13d10m20.3s\", frame=\"icrs\") # 4829/5/27\n",
+ "pos = SkyCoord(\"1h2m05.83s -2d11m20.3s\", frame=\"icrs\") # 2699/4/71\n",
+ "region = SDSS.query_region(pos, radius=\"5 arcsec\")\n",
+ "run, camcol, field = region[\"run\"][0], region[\"camcol\"][0], region[\"field\"][0]\n",
+ "print(\"run:\", run, \"camcol:\", camcol, \"field:\", field)\n",
+ "bliss_client.load_survey(\"sdss\", run, camcol, field, download_dir=Path(\"data/sdss\"))"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Get predictions for the downloaded dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " from n params module arguments \n",
+ " 0 -1 1 3328 yolov5.models.common.Conv [2, 64, 5, 1] \n",
+ " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n",
+ " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n",
+ " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n",
+ " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n",
+ " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n",
+ " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n",
+ " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n",
+ " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n",
+ " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n",
+ " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n",
+ " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n",
+ " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
+ " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n",
+ " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n",
+ " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n",
+ " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n",
+ " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n",
+ " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n",
+ " 19 [17] 1 16918 yolov5.models.yolo.Detect [17, [[4, 4]], [768]] \n",
+ "Model summary: 275 layers, 30770326 parameters, 30770326 gradients, 363.2 GFLOPs\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "est_cat_dl, est_cat_table_dl, galaxy_params_table_dl, pred_table_dl = bliss_client.predict_sdss(\n",
+ " data_path=\"data/sdss\",\n",
+ " weight_save_path=\"tutorial_encoder/0.pt\",\n",
+ " predict={\"dataset\": {\"run\": 1011, \"camcol\": 3, \"fields\": [44]}}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'detection_precision': tensor(0.), 'detection_recall': tensor(0.), 'f1': tensor(nan), 'avg_distance': tensor(63.40228), 'n_matches': tensor(0), 'n_matches_gal_coadd': tensor(0), 'class_acc': tensor(nan), 'gal_tp': tensor(0), 'gal_fp': tensor(0), 'gal_fn': tensor(0), 'gal_tn': tensor(0)}\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "from bliss.metrics import BlissMetrics\n",
+ "from bliss.surveys.sdss import PhotoFullCatalog\n",
+ "\n",
+ "sdss_data_path = \"/data/scratch/zhteoh/730-tutorial/data/sdss\"\n",
+ "photo_cat = PhotoFullCatalog.from_file(sdss_data_path, run=94, camcol=1, field=12, band=2)\n",
+ "\n",
+ "est_cat_cuda = est_cat_dl.to(torch.device(\"cpu\"))\n",
+ "photo_cat_cuda = photo_cat.to(torch.device(\"cpu\"))\n",
+ "\n",
+ "metrics = BlissMetrics()\n",
+ "results = metrics(est_cat_cuda, photo_cat_cuda)\n",
+ "\n",
+ "print(results)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ " Bokeh Plot\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "bliss_client.plot_predictions_in_notebook()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entries: 44\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Table length=5\n",
+ "\n",
+ "idx | star_log_fluxes | star_fluxes | galaxy_bools | star_bools | galaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b] |
\n",
+ " | dex(nmgy) | nmgy | | | |
\n",
+ "0 | 13.136136 | 506933.97 | 1 | 0 | ( 2.7742e+06, 0.46441, 3.8738, 0.73526, 0.58455, 0.90716, 0.33317) |
\n",
+ "1 | 7.3220587 | 1513.3162 | 0 | 1 | ( 1896.2, 0.37948, 3.3034, 0.4585, 1.1716, 0.46595, 0.65604) |
\n",
+ "2 | 8.204057 | 3655.7507 | 1 | 0 | ( 8565.5, 0.3654, 2.574, 0.54229, 1.3557, 0.53009, 0.75448) |
\n",
+ "3 | 7.5341415 | 1870.8376 | 1 | 0 | ( 3435.6, 0.40961, 2.7876, 0.50608, 1.2632, 0.51647, 0.76225) |
\n",
+ "4 | 9.342729 | 11415.514 | 0 | 1 | ( 18045, 0.11461, 3.8812, 0.51686, 1.4116, 0.3745, 0.46171) |
\n",
+ "
\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(\"Number of entries:\", len(galaxy_params_table_dl))\n",
+ "est_cat_table_dl.show_in_notebook(display_length=5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entries: 44\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Table length=5\n",
+ "\n",
+ "idx | galaxy_flux | galaxy_disk_frac | galaxy_beta_radians | galaxy_disk_q | galaxy_a_d | galaxy_bulge_q | galaxy_a_b |
\n",
+ " | nmgy | | rad | | arcsec | | arcsec |
\n",
+ "0 | 2774175.5 | 0.46440548 | 3.873808 | 0.73525774 | 0.58454525 | 0.9071591 | 0.33316585 |
\n",
+ "1 | 1896.2295 | 0.37948343 | 3.3033574 | 0.45850033 | 1.1716244 | 0.4659467 | 0.65603966 |
\n",
+ "2 | 8565.486 | 0.3654013 | 2.573972 | 0.54229295 | 1.3557259 | 0.5300871 | 0.75447744 |
\n",
+ "3 | 3435.5967 | 0.4096051 | 2.787564 | 0.5060757 | 1.2631959 | 0.5164672 | 0.76224524 |
\n",
+ "4 | 18044.533 | 0.11460639 | 3.881203 | 0.51685834 | 1.4116069 | 0.3744991 | 0.46170548 |
\n",
+ "
\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(\"Number of entries:\", len(galaxy_params_table_dl))\n",
+ "galaxy_params_table_dl.show_in_notebook(display_length=5)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Inspect probabilistic predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entries: 24964\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Table length=5\n",
+ "\n",
+ "idx | on_prob_false | on_prob_true | star_log_flux_mean | star_log_flux_std | galaxy_prob_false | galaxy_prob_true | galsim_flux_mean | galsim_flux_std | galsim_disk_frac_mean | galsim_disk_frac_std | galsim_beta_radians_mean | galsim_beta_radians_std | galsim_disk_q_mean | galsim_disk_q_std | galsim_a_d_mean | galsim_a_d_std | galsim_bulge_q_mean | galsim_bulge_q_std | galsim_a_b_mean | galsim_a_b_std |
\n",
+ " | | | dex(nmgy) | dex(nmgy) | | | nmgy | nmgy | | | rad | rad | | | arcsec | arcsec | | | arcsec | arcsec |
\n",
+ "0 | 0.99957013 | 0.0004298503 | 6.447689 | 0.33776873 | 0.0063510537 | 0.99364895 | 6.294058 | 0.366741 | 0.28619456 | 1.7372582 | 0.047299623 | 1.7643875 | -0.023828745 | 1.7356865 | 0.72286606 | 0.75291 | -0.19722128 | 1.6656573 | -0.20711589 | 0.6738205 |
\n",
+ "1 | 0.99948543 | 0.00051456114 | 6.5180864 | 0.32964587 | 0.011207759 | 0.98879224 | 6.4860725 | 0.39546797 | 0.20310283 | 1.7870133 | -0.01965785 | 1.7278929 | -0.038861513 | 1.7239002 | 0.7024548 | 0.7723759 | -0.20753884 | 1.8130727 | -0.30232072 | 0.72798145 |
\n",
+ "2 | 0.99968195 | 0.00031806232 | 6.4181337 | 0.35712224 | 0.0046512485 | 0.99534875 | 6.3676615 | 0.40050104 | 0.36024904 | 1.7220504 | -0.13102627 | 1.6104832 | -0.10612297 | 1.7426233 | 0.6726258 | 0.74757767 | -0.1960628 | 1.8077326 | -0.33668423 | 0.7270691 |
\n",
+ "3 | 0.9998234 | 0.00017659587 | 6.1481514 | 0.37125352 | 0.0027581453 | 0.99724185 | 6.082719 | 0.39653826 | 0.39119744 | 1.5810724 | -0.16671276 | 1.5083982 | -0.16777039 | 1.6667955 | 0.6224456 | 0.7258638 | -0.28959656 | 1.7095209 | -0.37969136 | 0.70144725 |
\n",
+ "4 | 0.99955606 | 0.00044391383 | 6.6809397 | 0.3411185 | 0.006544888 | 0.9934551 | 6.61392 | 0.38903677 | 0.2801323 | 1.7113833 | -0.11307597 | 1.719985 | 0.010611534 | 1.8483651 | 0.73093915 | 0.7584289 | -0.12461686 | 1.8759558 | -0.26268005 | 0.71290445 |
\n",
+ "
\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(\"Number of entries:\", len(pred_table_dl))\n",
+ "pred_table_dl.show_in_notebook(display_length=5)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/poetry.lock b/poetry.lock
index 96b3cbb74..886a3c263 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3199,6 +3199,26 @@ traitlets = ">=5.1"
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"]
test = ["pep440", "pre-commit", "pytest", "testpath"]
+[[package]]
+name = "nbsphinx"
+version = "0.9.2"
+description = "Jupyter Notebook Tools for Sphinx"
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "nbsphinx-0.9.2-py3-none-any.whl", hash = "sha256:2746680ece5ad3b0e980639d717a5041a1c1aafb416846b72dfaeecc306bc351"},
+ {file = "nbsphinx-0.9.2.tar.gz", hash = "sha256:540db7f4066347f23d0650c4ae8e7d85334c69adf749e030af64c12e996ff88e"},
+]
+
+[package.dependencies]
+docutils = "*"
+jinja2 = "*"
+nbconvert = "!=5.4"
+nbformat = "*"
+sphinx = ">=1.8"
+traitlets = ">=5"
+
[[package]]
name = "nbstripout"
version = "0.6.1"
@@ -4113,6 +4133,18 @@ tomlkit = ">=0.10.1"
spelling = ["pyenchant (>=3.2,<4.0)"]
testutils = ["gitpython (>3)"]
+[[package]]
+name = "pypandoc"
+version = "1.11"
+description = "Thin wrapper for pandoc."
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "pypandoc-1.11-py3-none-any.whl", hash = "sha256:b260596934e9cfc6513056110a7c8600171d414f90558bf4407e68b209be8007"},
+ {file = "pypandoc-1.11.tar.gz", hash = "sha256:7f6d68db0e57e0f6961bec2190897118c4d305fc2d31c22cd16037f22ee084a5"},
+]
+
[[package]]
name = "pyparsing"
version = "3.0.9"
@@ -6002,4 +6034,4 @@ tqdm = ">=4.64.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "454a2b76a5f5ca6dfb7076c114e5fed2b403dbf7f18b256f4c44a84aa6586b71"
+content-hash = "9b3b8708de03e5bb5ab3f510a63b9fe48a1f62991a8d7aaa0549e2082683d73a"
diff --git a/pyproject.toml b/pyproject.toml
index 4add4b986..135cba54e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,6 +67,8 @@ sphinx-rtd-theme = ">=0.5.2"
torch-tb-profiler = "^0.4.1"
tqdm = ">=4.62.3"
wemake-python-styleguide = ">=0.16.1"
+nbsphinx = "^0.9.2"
+pypandoc = "^1.11"
[build-system]
build-backend = "poetry.core.masonry.api"