Skip to content

Add ensemble with regularization towards initial parameters. #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions enn/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
# Einsum MLP
from enn.networks.einsum_mlp import EnsembleMLP
from enn.networks.einsum_mlp import make_einsum_ensemble_mlp_enn
from enn.networks.einsum_mlp import make_ensemble_mlp_regularized_towards_prior
from enn.networks.einsum_mlp import make_ensemble_mlp_with_prior_enn
from enn.networks.einsum_mlp import make_ensemble_prior
# Ensemble
Expand Down
41 changes: 41 additions & 0 deletions enn/networks/einsum_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,47 @@ def apply_with_prior(
return networks_base.EnnArray(apply_with_prior, enn.init, enn.indexer)


def make_ensemble_mlp_regularized_towards_prior(
output_sizes: Sequence[int],
dummy_input: chex.Array,
num_ensemble: int,
output_scale: float = 1.,
nonzero_bias: bool = True,
seed: int = 999,
) -> networks_base.EnnArray:
"""Factory method to create fast einsum MLP ensemble with matched prior.

Args:
output_sizes: Sequence of integer sizes for the MLPs.
dummy_input: Example x input for prior initialization.
num_ensemble: Integer number of elements in the ensemble.
output_scale: Float rescaling of the output of the network.
nonzero_bias: Whether to make the initial layer bias nonzero.
seed: integer seed for prior init.

Returns:
EpistemicNetwork ENN of the ensemble of MLP with matches prior.
"""

enn = make_einsum_ensemble_mlp_enn(output_sizes, num_ensemble, nonzero_bias)
init_key, _ = jax.random.split(jax.random.PRNGKey(seed))
prior_params, _ = enn.init(init_key, dummy_input, jnp.array([]))
prior_params = jax.lax.stop_gradient(prior_params)
# Apply function selects the appropriate index of the ensemble output.
def apply_with_prior(
params: hk.Params,
state: hk.State,
x: chex.Array,
z: base.Index,
) -> Tuple[chex.Array, hk.State]:
combined_params = jax.tree_map(lambda p1, p2: p1+p2, params, prior_params)
ensemble_train, state = enn.apply(combined_params, state, x, z)
output = ensemble_train * output_scale
return output, state

return networks_base.EnnArray(apply_with_prior, enn.init, enn.indexer)


# TODO(author3): Come up with a better name and use ensembles.py instead.
def make_ensemble_prior(output_sizes: Sequence[int],
num_ensemble: int,
Expand Down