Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ again.

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
[GitHub Help](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) for more
information on using pull requests.
39 changes: 15 additions & 24 deletions enn/colabs/enn_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -104,25 +104,19 @@
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title Development imports\n",
"from typing import Callable, NamedTuple\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import plotnine as gg\n",
"\n",
"from acme.utils.loggers.terminal import TerminalLogger\n",
"import dataclasses\n",
"import chex\n",
"import haiku as hk\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds"
"import optax"
]
},
{
Expand All @@ -135,15 +129,11 @@
"outputs": [],
"source": [
"#@title ENN imports\n",
"import enn\n",
"from enn import losses\n",
"from enn import networks\n",
"from enn import supervised\n",
"from enn import base\n",
"from enn import data_noise\n",
"from enn import utils\n",
"from enn.supervised import classification_data\n",
"from enn.supervised import regression_data\n"
"from enn.supervised import classification_data, regression_data"
]
},
{
Expand Down Expand Up @@ -173,6 +163,7 @@
" learning_rate: float = 1e-3\n",
" noise_std: float = 0.1\n",
"\n",
"\n",
"FLAGS = Config()"
]
},
Expand Down Expand Up @@ -202,7 +193,7 @@
"# Logger\n",
"logger = TerminalLogger('supervised_regression')\n",
"\n",
"# Create Ensemble ENN with a prior network \n",
"# Create Ensemble ENN with a prior network\n",
"enn = networks.MLPEnsembleMatchedPrior(\n",
" output_sizes=[50, 50, 1],\n",
" dummy_input=next(dataset).x,\n",
Expand All @@ -211,11 +202,11 @@
" seed=FLAGS.seed,\n",
")\n",
"\n",
"# L2 loss on perturbed outputs \n",
"# L2 loss on perturbed outputs\n",
"noise_fn = data_noise.GaussianTargetNoise(enn, FLAGS.noise_std, FLAGS.seed)\n",
"single_loss = losses.add_data_noise(losses.L2Loss(), noise_fn)\n",
"loss_fn = losses.average_single_index_loss(single_loss, FLAGS.num_index_samples)\n",
" \n",
"\n",
"# Optimizer\n",
"optimizer = optax.adam(FLAGS.learning_rate)\n",
"\n",
Expand Down
30 changes: 10 additions & 20 deletions enn/colabs/epinet_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,26 +78,20 @@
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title Development imports\n",
"from typing import Callable, NamedTuple\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import plotnine as gg\n",
"\n",
"from acme.utils.loggers.terminal import TerminalLogger\n",
"import dataclasses\n",
"import chex\n",
"import haiku as hk\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import dill\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds"
"import dill"
]
},
{
Expand All @@ -110,12 +104,8 @@
"outputs": [],
"source": [
"#@title ENN imports\n",
"import enn\n",
"from enn import datasets\n",
"from enn.checkpoints import base as checkpoint_base\n",
"from enn.networks.epinet import base as epinet_base\n",
"from enn.checkpoints import utils\n",
"from enn.checkpoints import imagenet\n",
"from enn.checkpoints import catalog\n",
"from enn import metrics as enn_metrics"
]
Expand Down