diff --git a/enn/networks/__init__.py b/enn/networks/__init__.py index 1b7d05f..fdbf5fc 100644 --- a/enn/networks/__init__.py +++ b/enn/networks/__init__.py @@ -41,6 +41,7 @@ # Einsum MLP from enn.networks.einsum_mlp import EnsembleMLP from enn.networks.einsum_mlp import make_einsum_ensemble_mlp_enn +from enn.networks.einsum_mlp import make_ensemble_mlp_regularized_towards_prior from enn.networks.einsum_mlp import make_ensemble_mlp_with_prior_enn from enn.networks.einsum_mlp import make_ensemble_prior # Ensemble diff --git a/enn/networks/einsum_mlp.py b/enn/networks/einsum_mlp.py index 194493e..106a600 100644 --- a/enn/networks/einsum_mlp.py +++ b/enn/networks/einsum_mlp.py @@ -117,6 +117,47 @@ def apply_with_prior( return networks_base.EnnArray(apply_with_prior, enn.init, enn.indexer) +def make_ensemble_mlp_regularized_towards_prior( + output_sizes: Sequence[int], + dummy_input: chex.Array, + num_ensemble: int, + output_scale: float = 1., + nonzero_bias: bool = True, + seed: int = 999, +) -> networks_base.EnnArray: + """Factory method to create fast einsum MLP ensemble with matched prior. + + Args: + output_sizes: Sequence of integer sizes for the MLPs. + dummy_input: Example x input for prior initialization. + num_ensemble: Integer number of elements in the ensemble. + output_scale: Float rescaling of the output of the network. + nonzero_bias: Whether to make the initial layer bias nonzero. + seed: integer seed for prior init. + + Returns: + EpistemicNetwork ENN of the ensemble of MLP with matches prior. + """ + + enn = make_einsum_ensemble_mlp_enn(output_sizes, num_ensemble, nonzero_bias) + init_key, _ = jax.random.split(jax.random.PRNGKey(seed)) + prior_params, _ = enn.init(init_key, dummy_input, jnp.array([])) + prior_params = jax.lax.stop_gradient(prior_params) + # Apply function selects the appropriate index of the ensemble output. + def apply_with_prior( + params: hk.Params, + state: hk.State, + x: chex.Array, + z: base.Index, + ) -> Tuple[chex.Array, hk.State]: + combined_params = jax.tree_map(lambda p1, p2: p1+p2, params, prior_params) + ensemble_train, state = enn.apply(combined_params, state, x, z) + output = ensemble_train * output_scale + return output, state + + return networks_base.EnnArray(apply_with_prior, enn.init, enn.indexer) + + # TODO(author3): Come up with a better name and use ensembles.py instead. def make_ensemble_prior(output_sizes: Sequence[int], num_ensemble: int,