diff --git a/disentangled_rnns/example.py b/disentangled_rnns/example.py index fb5dc84..75bfe9f 100644 --- a/disentangled_rnns/example.py +++ b/disentangled_rnns/example.py @@ -15,9 +15,12 @@ """Example DisRNN workflow: Define a dataset, train network, inspect the fit. """ +import copy + from absl import app from absl import flags from disentangled_rnns.library import disrnn +from disentangled_rnns.library import plotting from disentangled_rnns.library import rnn_utils from disentangled_rnns.library import two_armed_bandits import optax @@ -64,18 +67,18 @@ def main(_) -> None: ) # Define the disRNN architecture - update_mlp_shape = (3, 5, 5) - choice_mlp_shape = (2, 2) - latent_size = 5 - - def make_network(): - return disrnn.HkDisRNN( - update_mlp_shape=update_mlp_shape, - choice_mlp_shape=choice_mlp_shape, - latent_size=latent_size, - obs_size=2, - target_size=2, - ) + disrnn_config = disrnn.DisRnnConfig( + obs_size=2, + output_size=2, + latent_size=5, + update_net_n_units_per_layer=8, + update_net_n_layers=4, + choice_net_n_units_per_layer=4, + choice_net_n_layers=2, + x_names=dataset.x_names, + y_names=dataset.y_names, + ) + make_disrnn = lambda: disrnn.HkDisentangledRNN(disrnn_config) opt = optax.adam(learning_rate=FLAGS.learning_rate) @@ -85,7 +88,7 @@ def make_network(): # Warmup training with no information penalty params, _, _ = rnn_utils.train_network( - make_network, + make_disrnn, training_dataset=dataset, validation_dataset=dataset_eval, loss="penalized_categorical", @@ -99,7 +102,7 @@ def make_network(): # Additional training using information penalty params, _, _ = rnn_utils.train_network( - make_network, + make_disrnn, training_dataset=dataset, validation_dataset=dataset_eval, loss="penalized_categorical", @@ -114,28 +117,20 @@ def make_network(): ########################### # Inspecting a fit disRNN # ########################### - - # Eval mode runs the network with no noise - def make_network_eval(): - return disrnn.HkDisRNN( - 
update_mlp_shape=update_mlp_shape, - choice_mlp_shape=choice_mlp_shape, - latent_size=latent_size, - obs_size=2, - target_size=2, - eval_mode=True, - ) - - disrnn.plot_bottlenecks(params, make_network_eval) - disrnn.plot_update_rules(params, make_network_eval) + # Plot bottleneck structure and update rules + plotting.plot_bottlenecks(params, disrnn_config) + plotting.plot_update_rules(params, disrnn_config) ############################## # Eval disRNN on unseen data # ############################## - + # Run the network in noiseless mode to see evolution of states over time + config_noiseless = copy.deepcopy(disrnn_config) + config_noiseless.noiseless_mode = True + make_noiseless_disrnn = lambda: disrnn.HkDisentangledRNN(config_noiseless) xs, _ = next(dataset_eval) # pylint: disable-next=unused-variable - _, network_states = rnn_utils.eval_network(make_network_eval, params, xs) + _, network_states = rnn_utils.eval_network(make_noiseless_disrnn, params, xs) if __name__ == "__main__": diff --git a/disentangled_rnns/library/disrnn.py b/disentangled_rnns/library/disrnn.py index 09a2535..79258f0 100644 --- a/disentangled_rnns/library/disrnn.py +++ b/disentangled_rnns/library/disrnn.py @@ -13,39 +13,58 @@ # limitations under the License. """Disentangled RNN and plotting functions.""" +import dataclasses +from typing import Optional, Callable, Any, Sequence -from typing import Any, Callable, Iterable - -from disentangled_rnns.library import rnn_utils import haiku as hk import jax import jax.numpy as jnp -import matplotlib as mpl -import matplotlib.pyplot as plt import numpy as np -def kl_gaussian(mean: jnp.ndarray, var: jnp.ndarray) -> jnp.ndarray: - r"""Calculate KL divergence between given and standard gaussian distributions. +def information_bottleneck( + mus: jnp.ndarray, sigmas: jnp.ndarray +) -> tuple[jnp.ndarray, jnp.ndarray]: + r"""Output from an information bottleneck given a vector of means and vars. 
+ + Bottleneck outputs are sampled independently from Gaussian distributions with + the given means and variances. Bottleneck costs are computed as the KL + divergence between this sampling distribution and the unit Gaussian. KL(p, q) = H(p, q) - H(p) = -\int p(x)log(q(x))dx - -\int p(x)log(p(x))dx - = 0.5 * [log(|s2|/|s1|) - 1 + tr(s1/s2) + (m1-m2)^2/s2] - = 0.5 * [-log(|s1|) - 1 + tr(s1) + m1^2] (if m2 = 0, s2 = 1) - Args: - mean: mean vector of the first distribution - var: diagonal vector of covariance matrix of the first distribution + = 0.5 * [log(|s2|/|s1|) - 1 + tr(s1/s2) + (m1-m2)^2/s2] + = 0.5 * [-log(|s1|) - 1 + tr(s1) + m1^2] (if m2 = 0, s2 = 1) + Args: + mus: The means of the sampling distribution. Shape is (batch_size, + bottleneck_dims) + sigmas: The diagonal of the covariance matrix of the sampling distribution. + Shape is (bottleneck_dims) Returns: - A scalar representing KL divergence of the two Gaussian distributions. + bottleneck_output: The noisy output of the bottleneck. Shape is the same as + mus. + bottleneck_penalty: The KL cost of the bottleneck sample. Shape is + (batch_size,). """ + # Shape is (batch_size, bottleneck_dims) + bottleneck_output = mus + sigmas * jax.random.normal( + hk.next_rng_key(), jnp.shape(mus) + ) - return 0.5 * jnp.sum(-jnp.log(var) - 1.0 + var + jnp.square(mean), axis=-1) + # Shape is (batch_size, bottleneck_dims) + elementwise_kl = jnp.square(mus) + sigmas - 1.0 - jnp.log(sigmas) + # Shape is (batch_size,) + bottleneck_penalty = 0.5 * jnp.sum( + elementwise_kl, axis=tuple(range(1, elementwise_kl.ndim)) + ) + + return bottleneck_output, bottleneck_penalty def reparameterize_sigma( hk_param: jnp.ndarray, min_sigma: float = 1e-5 ) -> jnp.ndarray: - """Reparameterizes bottleneck sigma for easy fitting. + """Reparameterizes bottleneck sigma for easy fitting. Args: hk_param: The haiku parameter corresponding to a bottleneck sigma. 
Range @@ -57,363 +76,492 @@ def reparameterize_sigma( return jnp.abs(hk_param) + min_sigma -class HkDisRNN(hk.RNNCore): - """Disentangled RNN.""" +@dataclasses.dataclass +class DisRnnConfig: + """Specifies an architecture and configuration for a Disentangled RNN. + + Attributes: + obs_size: Number of dimensions in the observation vector + output_size: Number of dimensions the disRNN will output + (logits or predicted targets) + latent_size: Number of recurrent variables + update_net_n_units_per_layer: Number of units in each layer of the update + networks + update_net_n_layers: Number of layers in the update networks + choice_net_n_units_per_layer: Number of units in each layer of the choice + network + choice_net_n_layers: Number of layers in the choice network + noiseless_mode: Allows turning off the bottlenecks e.g. for evaluation + latent_penalty_scale: Multiplier for KL cost on the latent bottlenecks + choice_net_penalty_scale: Multiplier for KL cost on choice net bottlenecks + update_net_penalty_scale: Multiplier for KL cost on update net bottlenecks + l2_scale: Multiplier for L2 penalty on hidden layer weights in both update + and choice networks + activation: String defining an activation function. Must be in jax.nn. + max_latent_value: Cap on the possible absolute value of a latent. Used to + prevent runaway latents resulting in NaNs + + x_names: Names of the observation vector elements. Must have length obs_size + y_names: Names of the target vector elements. 
Must have length target_size + """ - def __init__( - self, - obs_size: int = 2, - target_size: int = 1, - latent_size: int = 10, - update_mlp_shape: Iterable[int] = (10, 10, 10), - choice_mlp_shape: Iterable[int] = (10, 10, 10), - eval_mode: float = 0, - beta_scale: int = 1, - activation: Callable[[Any], Any] = jax.nn.relu, - ): - super().__init__() + obs_size: int = 2 + output_size: int = 2 + latent_size: int = 10 + + update_net_n_units_per_layer: int = 10 + update_net_n_layers: int = 2 + choice_net_n_units_per_layer: int = 2 + choice_net_n_layers: int = 2 + activation: str = 'leaky_relu' + + noiseless_mode: bool = False + + latent_penalty_scale: float = 1.0 + choice_net_penalty_scale: float = 0.1 + update_net_penalty_scale: float = 0.1 + l2_scale: float = 0.01 + + max_latent_value: float = 2. + + x_names: Optional[list[str]] = None + y_names: Optional[list[str]] = None + + def __post_init__(self): + """Checks that the configuration is valid.""" + + expected_len_xnames = self.obs_size + if self.x_names is None: + self.x_names = [f'Observation {i}' for i in range(expected_len_xnames)] + if len(self.x_names) != expected_len_xnames: + raise ValueError( + f'Based on x_names {self.x_names}, expected obs_size to be ' + f'{expected_len_xnames} but got {self.obs_size}' + ) + + # Check activation is in jax.nn + try: + getattr(jax.nn, self.activation) + except AttributeError as e: + raise ValueError( + f'Activation {self.activation} not found in jax.nn. Provided value ' + f'was {self.activation}' + ) from e + + +class ResMLP(hk.Module): + """MLP modified to apply serial updates to a residual stream. 
+ + Attributes: + input_size: Dimension of the input vector + output_size: Dimensions of the output vector + n_layers: Number of layers + n_units_per_layer: Dimension of the stream and of each layer + activation_fn: Activation function + w_initializer: Initializer for the weights + b_initializer: Initializer for the biases + name: Optional name, which affects the names of the haiku parameters + """ - self._target_size = target_size - self._latent_size = latent_size - self._update_mlp_shape = update_mlp_shape - self._choice_mlp_shape = choice_mlp_shape - self._beta_scale = beta_scale - self._eval_mode = eval_mode - self._activation = activation - - # Each update MLP gets input from both the latents and the observations. - # It has a sigma and a multiplier associated with each. - mlp_input_size = latent_size + obs_size - # At init the bottlenecks should all be open: sigmas small and multipliers 1 - update_mlp_sigma_params = hk.get_parameter( - 'update_mlp_sigma_params', - (mlp_input_size, latent_size), - init=hk.initializers.RandomUniform(minval=-3, maxval=-2), + def __init__(self, + input_size: int, + output_size: int, + n_layers: int = 5, + n_units_per_layer: int = 5, + activation_fn: Callable[[Any], Any] = jax.nn.relu, + name=None): + super().__init__(name=name) + + self.n_layers = n_layers + self.n_units_per_layer = n_units_per_layer + self.activation_fn = activation_fn + + # Input layer will be a linear projection from input size to stream size + # To keep activation magnitudes similar, we initialise weights in the range + # 1/sqrt(input_size) + scale = 1 / jnp.sqrt(jnp.float32(input_size)) + self._input_weights = hk.get_parameter( + 'input_weights', + (input_size, self.n_units_per_layer), + init=hk.initializers.RandomNormal(stddev=scale), ) - # Reparameterize the sigmas to be positive definite and have a minimum value - self._update_mlp_sigmas = reparameterize_sigma( - update_mlp_sigma_params - ) * (1 - eval_mode) - self._update_mlp_multipliers = 
hk.get_parameter( - 'update_mlp_gates', - (mlp_input_size, latent_size), - init=hk.initializers.Constant(constant=1), + self._input_biases = hk.get_parameter( + 'input_biases', + (self.n_units_per_layer,), + init=hk.initializers.Constant(0.0), ) - # Latents will also go through a bottleneck - self.latent_sigma_params = hk.get_parameter( - 'latent_sigma_params', - (latent_size,), - init=hk.initializers.RandomUniform(minval=-3, maxval=-2), + # Output layer will be a linear projection from stream size to output size, + # To keep activation magnitudes similar, we scale initial weights by + # 1 / sqrt(n_units_per_layer) + scale = 1 / jnp.sqrt(jnp.float32(self.n_units_per_layer)) + self._output_weights = hk.get_parameter( + 'output_weights', + (n_units_per_layer, output_size), + init=hk.initializers.RandomNormal(stddev=scale), ) - self._latent_sigmas = reparameterize_sigma( - self.latent_sigma_params - ) * (1 - eval_mode) - - # Latent initial values are also free parameters - self._latent_inits = hk.get_parameter( - 'latent_inits', - (latent_size,), - init=hk.initializers.RandomUniform(minval=-1, maxval=1), + self._output_biases = hk.get_parameter( + 'output_biases', (output_size,), init=hk.initializers.Constant(0.0) ) - def __call__(self, observations: jnp.ndarray, prev_latents: jnp.ndarray): - penalty = 0 # Accumulator for KL costs - - ################ - # UPDATE MLPs # - ################ - # Each update MLP updates one latent - # It sees previous latents and current observation - # It outputs a weight and an update to apply to its latent - - # update_mlp_mus_unscaled: (batch_size, obs_size + latent_size) - update_mlp_mus_unscaled = jnp.concatenate( - (observations, prev_latents), axis=1 - ) - # update_mlp_mus: (batch_size, obs_size + latent_size, latent_size) - update_mlp_mus = ( - jnp.expand_dims(update_mlp_mus_unscaled, 2) - * self._update_mlp_multipliers + # Hidden layers will each be a single fully connected layer. 
+ # Each layer will increase the variance of the stream, so we scale initial + # weights both by n_units_per_layer and n_layers. + hidden_w_init_scale = 1 / jnp.sqrt( + jnp.float32(n_units_per_layer * n_layers) ) - # update_mlp_sigmas: (obs_size + latent_size, latent_size) - update_mlp_sigmas = self._update_mlp_sigmas * (1 - self._eval_mode) - # update_mlp_inputs: (batch_size, obs_size + latent_size, latent_size) - update_mlp_inputs = update_mlp_mus + update_mlp_sigmas * jax.random.normal( - hk.next_rng_key(), update_mlp_mus.shape + hidden_w_init = hk.initializers.RandomNormal(stddev=hidden_w_init_scale) + self._hidden_layer_weights = [] + self._hidden_layer_biases = [] + for hidden_layer_i in range(self.n_layers): + self._hidden_layer_weights.append( + hk.get_parameter( + f'layer_{hidden_layer_i}_weights', + (self.n_units_per_layer, self.n_units_per_layer), + init=hidden_w_init, + ) + ) + self._hidden_layer_biases.append( + hk.get_parameter( + f'layer_{hidden_layer_i}_biases', + (self.n_units_per_layer,), + init=hk.initializers.Constant(0.0), + ) + ) + + # Compute sum of squares of all hidden layer weights. This will be passed on + # and can be used to compute an L2 (ridge) penalty. + self.l2 = ( + jnp.sum(jnp.square(jnp.array(self._hidden_layer_weights))) ) - # new_latents: (batch_size, latent_size) - new_latents = jnp.zeros(shape=(prev_latents.shape)) - # Loop over latents. Update each usings its own MLP - for mlp_i in jnp.arange(self._latent_size): - penalty += self._beta_scale * kl_gaussian( - update_mlp_mus[:, :, mlp_i], update_mlp_sigmas[:, mlp_i] + def __call__(self, inputs): + + # Linear projection of inputs to the size of the residual stream + stream = jnp.dot(inputs, self._input_weights) + self._input_biases + + # Each iteration adds a layer. + # Each hidden layer additively modifies the residual stream. 
+ for hidden_layer_i in range(self.n_layers): + # (batch_size, stream_size) + layer_activations = ( + jnp.dot(stream, self._hidden_layer_weights[hidden_layer_i]) + + self._hidden_layer_biases[hidden_layer_i] ) - update_mlp_output = hk.nets.MLP( - self._update_mlp_shape, - activation=self._activation, - )(update_mlp_inputs[:, :, mlp_i]) - # update, w, new_latent: (batch_size,) - update = hk.Linear(1,)( - update_mlp_output - )[:, 0] - w = jax.nn.sigmoid(hk.Linear(1)(update_mlp_output))[:, 0] - new_latent = w * update + (1 - w) * prev_latents[:, mlp_i] - new_latents = new_latents.at[:, mlp_i].set(new_latent) - - ##################### - # Global Bottleneck # - ##################### - # noised_up_latents: (batch_size, latent_size) - noised_up_latents = new_latents + self._latent_sigmas * jax.random.normal( - hk.next_rng_key(), new_latents.shape - ) - penalty += kl_gaussian(new_latents, self._latent_sigmas) - - ############### - # CHOICE MLP # - ############### - # Predict targets for current time step - # This sees previous state but does _not_ see current observation - choice_mlp_output = hk.nets.MLP( - self._choice_mlp_shape, activation=self._activation - )(noised_up_latents) - # (batch_size, target_size) - y_hat = hk.Linear(self._target_size)(choice_mlp_output) - - # Append the penalty, so that rnn_utils can apply it as part of the loss - penalty = jnp.expand_dims(penalty, 1) # (batch_size, 1) - # If we are in eval mode, there should be no penalty - penalty = penalty * (1 - self._eval_mode) - - # output: (batch_size, target_size + 1) - output = jnp.concatenate((y_hat, penalty), axis=1) - - return output, noised_up_latents - - def initial_state(self, batch_size): - # (batch_size, latent_size) - latents = jnp.ones([batch_size, self._latent_size]) * self._latent_inits - return latents + layer_output = self.activation_fn(layer_activations) + stream += layer_output + # Linear projection to the appropriate output size + output = jnp.dot(stream, self._output_weights) + 
self._output_biases + + return output, self.l2 + + +def get_initial_bottleneck_params( + shape: Sequence[int], name: str, +) -> tuple[jnp.ndarray, jnp.ndarray]: + """Defines a bottleneck with a sigma and a multiplier.""" + # At init the bottlenecks should all be open: sigmas small and multipliers 1 + sigma_params = hk.get_parameter( + name + '_sigma_params', + shape, + init=hk.initializers.RandomUniform(minval=0.0, maxval=0.05), + ) + sigmas = reparameterize_sigma(sigma_params) + multipliers = hk.get_parameter( + name + '_multipliers', + shape, + init=hk.initializers.Constant(constant=1), + ) + return sigmas, multipliers -def plot_bottlenecks(params, sort_latents=True, obs_names=None): - """Plot the bottleneck sigmas from an hk.DisRNN.""" - params_disrnn = params['hk_dis_rnn'] - latent_dim = params_disrnn['latent_sigma_params'].shape[0] - obs_dim = params_disrnn['update_mlp_sigma_params'].shape[0] - latent_dim +class HkDisentangledRNN(hk.RNNCore): + """Disentangled RNN.""" - if obs_names is None: - if obs_dim == 2: - obs_names = ['Choice', 'Reward'] - elif obs_dim == 5: - obs_names = ['A', 'B', 'C', 'D', 'Reward'] - else: - obs_names = np.arange(1, obs_dim+1) + def __init__( + self, + config: DisRnnConfig, + ): + super().__init__() + self._l2_scale = config.l2_scale + self._noiseless_mode = config.noiseless_mode + + self._obs_size = config.obs_size + self._output_size = config.output_size + self._latent_size = config.latent_size + + self._update_net_n_units_per_layer = config.update_net_n_units_per_layer + self._update_net_n_layers = config.update_net_n_layers + self._choice_net_n_units_per_layer = config.choice_net_n_units_per_layer + self._choice_net_n_layers = config.choice_net_n_layers + + self._latent_penalty_scale = config.latent_penalty_scale + self._choice_net_penalty_scale = config.choice_net_penalty_scale + self._update_net_penalty_scale = config.update_net_penalty_scale + self._noiseless_mode = config.noiseless_mode + self._activation = getattr(jax.nn, 
config.activation) + self._max_latent_value = config.max_latent_value + + # Get Haiku parameters. IMPORTANT: if you are subclassing HkDisentangledRNN, + # you must override _get_haiku_parameters to add any new parameters that you + # need. This way of doing things is necessary for Haiku to work correctly. + self._get_haiku_parameters() + + def _get_haiku_parameters(self): + """Initializes parameters for the bottlenecks.""" + self._build_update_bottlenecks() + self._build_latent_bottlenecks() + self._build_choice_bottlenecks() + self._latent_inits = hk.get_parameter( + 'latent_inits', + (self._latent_size,), + init=hk.initializers.RandomUniform(minval=-0.1, maxval=0.1), + ) - latent_sigmas = reparameterize_sigma(params_disrnn['latent_sigma_params']) + def _build_update_bottlenecks(self): + """Initializes parameters for the update network bottlenecks.""" + # There is one Update Network per latent. Each one gets input from all + # latents and all observations. These inputs pass through bottlenecks, so + # they will need bottleneck params. + input_size = self._latent_size + self._obs_size + self._update_net_sigmas, self._update_net_multipliers = ( + get_initial_bottleneck_params( + shape=(input_size, self._latent_size), + name='update_net', + ) + ) - update_sigmas = reparameterize_sigma( - np.transpose( - params_disrnn['update_mlp_sigma_params'] - ) - ) + def _build_latent_bottlenecks(self): + """Initializes parameters for the latent bottlenecks.""" + # Latents will also go through a bottleneck after being updated. 
These + # bottlenecks do not need multipliers, the network output can rescale them + self._latent_sigmas, _ = ( + get_initial_bottleneck_params( + shape=(self._latent_size,), + name='latent', + ) + ) - if sort_latents: - latent_sigma_order = np.argsort( - latent_sigmas + def _build_choice_bottlenecks(self): + """Initializes parameters for the choice network bottlenecks.""" + # Choice network gets inputs from the latents, and has a bottleneck on each + self._choice_net_sigmas, self._choice_net_multipliers = ( + get_initial_bottleneck_params( + shape=(self._latent_size,), + name='choice_net', + ) ) - latent_sigmas = latent_sigmas[latent_sigma_order] - update_sigma_order = np.concatenate( - (np.arange(0, obs_dim, 1), obs_dim + latent_sigma_order), axis=0 + def initial_state(self, batch_size: Optional[int]) -> Any: + # (batch_size, latent_size) + latents = jnp.ones([batch_size, self._latent_size]) * self._latent_inits + return latents + + def update_latents( + self, update_rule_inputs: jnp.ndarray, prev_latent_values: jnp.ndarray + ) -> tuple[jnp.ndarray, jnp.ndarray]: + """Updates the latents using the update rules. + + Each latent is updated by a separate Update Network, which takes as input + all previous latents and all additional update rule inputs, and outputs an + update and a weight. New value of the latent will be the weighted average of + the previous latent and the update. + + Args: + update_rule_inputs: Additional inputs for the update rules. + prev_latent_values: The latents from the previous time step. + Returns: + new_latent_values: The updated latents. + penalty_increment: A penalty associated with the update. 
+ """ + # penalty_increment: (batch_size,) + batch_size = prev_latent_values.shape[0] + penalty_increment = jnp.zeros(shape=(batch_size,)) + + # (batch_size, obs_size + latent_size) + update_net_inputs = jnp.concatenate( + (update_rule_inputs, prev_latent_values), axis=1 ) - update_sigmas = update_sigmas[latent_sigma_order, :] - update_sigmas = update_sigmas[:, update_sigma_order] - - latent_names = np.arange(1, latent_dim + 1) - fig = plt.subplots(1, 2, figsize=(10, 5)) - plt.subplot(1, 2, 1) - plt.imshow(np.swapaxes([1 - latent_sigmas], 0, 1), cmap='Oranges') - plt.clim(vmin=0, vmax=1) - plt.yticks(ticks=range(latent_dim), labels=latent_names) - plt.xticks(ticks=[]) - plt.ylabel('Latent #') - plt.title('Latent Bottlenecks') - - plt.subplot(1, 2, 2) - plt.imshow(1 - update_sigmas, cmap='Oranges') - plt.clim(vmin=0, vmax=1) - plt.colorbar() - plt.yticks(ticks=range(latent_dim), labels=latent_names) - xlabels = np.concatenate((np.array(obs_names), latent_names)) - plt.xticks( - ticks=range(len(xlabels)), - labels=xlabels, - rotation='vertical', - ) - plt.ylabel('Latent #') - plt.title('Update MLP Bottlenecks') - return fig + # Expand to have a separate copy per update network: + # (batch_size, obs_size + latent_size, latent_size) + update_net_inputs = jnp.tile( + jnp.expand_dims(update_net_inputs, 2), + (1, 1, self._latent_size), + ) + # Apply multipliers. We want to do this whether or not we are in noiseless + # mode, so that the model will produce similar outputs in both modes. + update_net_inputs *= self._update_net_multipliers + # Apply an information bottleneck to the inputs. If we are in noiseless + # mode, we can skip this. + if not self._noiseless_mode: + update_net_inputs, update_net_kl_cost = information_bottleneck( + mus=update_net_inputs, + sigmas=self._update_net_sigmas, + ) + penalty_increment += self._update_net_penalty_scale * update_net_kl_cost + # Loop over latents. Update each using its own network. 
+ new_latent_values = jnp.zeros( + shape=(prev_latent_values.shape[0], self._latent_size) + ) + for net_i in jnp.arange(self._latent_size): + update_net_output, update_net_l2 = ResMLP( + input_size=update_net_inputs.shape[1], + output_size=2, + n_units_per_layer=self._update_net_n_units_per_layer, + n_layers=self._update_net_n_layers, + activation_fn=self._activation, + name='update_net', + )(update_net_inputs[:, :, net_i]) + # Add L2 to the penalty based on weights of the network + penalty_increment += self._l2_scale * update_net_l2 + # Update the latent multiplicatively, using a weight and a new target + # derived from the MLP output. + # Weight needs to be in [0, 1] + new_latent_weight = jax.nn.sigmoid(update_net_output[:, 0]) + # Target needs to be in [-max_latent_value, max_latent_value] + new_latent_target = update_net_output[:, 1] + new_latent_target = self._max_latent_value * jax.nn.tanh( + new_latent_target / self._max_latent_value + ) + # New latent value is weighted average of previous value and new target. + prev_latent_value = prev_latent_values[:, net_i] + new_latent_value = ( + 1 - new_latent_weight + ) * prev_latent_value + new_latent_weight * new_latent_target + new_latent_values = new_latent_values.at[:, net_i].set(new_latent_value) + + # Put latent values through an information bottleneck. If we are in + # noiseless mode, we can skip this. + if not self._noiseless_mode: + # new_latent_values: (batch_size, latent_size) + new_latent_values, latent_kl_cost = information_bottleneck( + mus=new_latent_values, sigmas=self._latent_sigmas + ) + penalty_increment += self._latent_penalty_scale * latent_kl_cost -def plot_update_rules(params, make_network): - """Generates visualizations of the update ruled of a HkDisRNN. 
- """ + return new_latent_values, penalty_increment - def step(xs, state): - core = make_network() - output, new_state = core(jnp.expand_dims(jnp.array(xs), axis=0), state) - return output, new_state + def predict_targets( + self, + new_latents: jnp.ndarray + ) -> tuple[jnp.ndarray, jnp.ndarray]: + """Predicts the targets using the choice network.""" + batch_size = new_latents.shape[0] + penalty_increment = jnp.zeros(shape=(batch_size,)) + + # Apply multipliers to the latents. We want to do this whether or not we are + # in noiseless mode, so that the model will produce similar outputs in both + # modes. + choice_net_inputs = new_latents * self._choice_net_multipliers + # Put the latents through an information bottleneck. If we are in noiseless + # mode, we can skip this and use the latents directly. + if not self._noiseless_mode: + choice_net_inputs, choice_net_kl_cost = information_bottleneck( + mus=choice_net_inputs, + sigmas=self._choice_net_sigmas + ) + penalty_increment += self._choice_net_penalty_scale * choice_net_kl_cost - _, step_hk = hk.transform(step) - key = jax.random.PRNGKey(0) - step_hk = jax.jit(step_hk) + predicted_targets, choice_net_l2 = ResMLP( + input_size=choice_net_inputs.shape[1], + output_size=self._output_size, + n_units_per_layer=self._choice_net_n_units_per_layer, + n_layers=self._choice_net_n_layers, + activation_fn=self._activation, + name='choice_net' + )(choice_net_inputs) + penalty_increment += self._l2_scale * choice_net_l2 - initial_state = np.array(rnn_utils.get_initial_state(make_network)) - reference_state = np.zeros(initial_state.shape) + return predicted_targets, penalty_increment - def plot_update_1d(params, unit_i, observations, titles): - lim = 3 - state_bins = np.linspace(-lim, lim, 20) - colormap = mpl.colormaps['viridis'].resampled(3) - colors = colormap.colors + def __call__(self, observations: jnp.ndarray, prev_latents: jnp.ndarray): + # Initial penalty values. 
Shape is (batch_size,) + batch_size = prev_latents.shape[0] + penalty = jnp.zeros(shape=(batch_size,)) - fig, ax = plt.subplots( - 1, len(observations), figsize=(len(observations) * 4, 5.5) + new_latents, penalty_increment = self.update_latents( + observations, prev_latents ) - plt.subplot(1, len(observations), 1) - plt.ylabel('Updated Activity') - - for observation_i, observation in enumerate(observations): - plt.subplot(1, len(observations), observation_i + 1) - - plt.plot((-3, 3), (-3, 3), '--', color='grey') - plt.plot((-3, 3), (0, 0), color='black') - plt.plot((0, 0), (-3, 3), color='black') - - delta_states = np.zeros(shape=(len(state_bins), 1)) - for s_i in np.arange(len(state_bins)): - state = reference_state - state[0, unit_i] = state_bins[s_i] - _, next_state = step_hk( - params, key, observation, state - ) - delta_states[s_i] = np.array(next_state[0, unit_i]) - plt.plot(state_bins, delta_states, color=colors[1]) + penalty += penalty_increment + predicted_targets, penalty_increment = self.predict_targets(new_latents) + penalty += penalty_increment - plt.title(titles[observation_i]) - plt.xlim(-lim, lim) - plt.ylim(-lim, lim) - plt.xlabel('Previous Activity') + # Output has shape (batch_size, output_size + 1). + # The first output_size elements are the predicted targets, and the last + # element is the penalty. We preassign instead of using concatenate to avoid + # errors caused by silent broadcasting. 
+ output_shape = (batch_size, self._output_size + 1) + output = jnp.zeros(output_shape) + output = output.at[:, :-1].set(predicted_targets) + output = output.at[:, -1].set(penalty) - if isinstance(ax, np.ndarray): - ax[observation_i].set_aspect('equal') - else: - ax.set_aspect('equal') - return fig + return output, new_latents - def plot_update_2d(params, unit_i, unit_input, observations, titles): - lim = 3 - state_bins = np.linspace(-lim, lim, 20) - colormap = mpl.colormaps['viridis'].resampled(len(state_bins)) - colors = colormap.colors +def log_bottlenecks(params, + open_thresh=0.1, + partially_open_thresh=0.25, + closed_thresh=0.9) -> dict[str, int]: + """Computes info about bottlenecks.""" - fig, ax = plt.subplots( - 1, len(observations), figsize=(len(observations) * 2 + 10, 5.5) - ) - plt.subplot(1, len(observations), 1) - plt.ylabel('Updated Latent ' + str(unit_i + 1) + ' Activity') - - for observation_i, observation in enumerate(observations): - plt.subplot(1, len(observations), observation_i + 1) - - plt.plot((-3, 3), (-3, 3), '--', color='grey') - plt.plot((-3, 3), (0, 0), color='black') - plt.plot((0, 0), (-3, 3), color='black') - - for si_i in np.arange(len(state_bins)): - delta_states = np.zeros(shape=(len(state_bins), 1)) - for s_i in np.arange(len(state_bins)): - state = reference_state - state[0, unit_i] = state_bins[s_i] - state[0, unit_input] = state_bins[si_i] - _, next_state = step_hk(params, key, observation, state) - delta_states[s_i] = np.array(next_state[0, unit_i]) - - plt.plot(state_bins, delta_states, color=colors[si_i]) - - plt.title(titles[observation_i]) - plt.xlim(-lim, lim) - plt.ylim(-lim, lim) - plt.xlabel('Latent ' + str(unit_i + 1) + ' Activity') - - if isinstance(ax, np.ndarray): - ax[observation_i].set_aspect('equal') - else: - ax.set_aspect('equal') - return fig - - latent_sigmas = reparameterize_sigma( - params['hk_dis_rnn']['latent_sigma_params'] + params_disrnn = params['hk_disentangled_rnn'] + + latent_sigmas = np.array( + 
reparameterize_sigma(params_disrnn['latent_sigma_params']) ) - update_sigmas = reparameterize_sigma( - np.transpose( - params['hk_dis_rnn']['update_mlp_sigma_params'] - ) + update_sigmas = np.array( + reparameterize_sigma( + np.transpose(params_disrnn['update_net_sigma_params']) ) - latent_order = np.argsort(latent_sigmas) - figs = [] - - # Loop over latents. Plot update rules - for latent_i in latent_order: - # If this latent's bottleneck is open - if latent_sigmas[latent_i] < 0.5: - # Which of its input bottlenecks are open? - update_mlp_inputs = np.argwhere(update_sigmas[latent_i] < 0.9) - choice_sensitive = np.any(update_mlp_inputs == 0) - reward_sensitive = np.any(update_mlp_inputs == 1) - # Choose which observations to use based on input bottlenecks - if choice_sensitive and reward_sensitive: - observations = ([0, 0], [0, 1], [1, 0], [1, 1]) - titles = ('Left, Unrewarded', - 'Left, Rewarded', - 'Right, Unrewarded', - 'Right, Rewarded') - elif choice_sensitive: - observations = ([0, 0], [1, 0]) - titles = ('Choose Left', 'Choose Right') - elif reward_sensitive: - observations = ([0, 0], [0, 1]) - titles = ('Rewarded', 'Unreward') - else: - observations = ([0, 0],) - titles = ('All Trials',) - # Choose whether to condition on other latent values - latent_sensitive = update_mlp_inputs[update_mlp_inputs > 1] - 2 - # Doesn't count if it depends on itself (this'll be shown no matter what) - latent_sensitive = np.delete( - latent_sensitive, latent_sensitive == latent_i + ) + choice_sigmas = np.array( + reparameterize_sigma( + np.transpose(params_disrnn['choice_net_sigma_params']) ) - if not latent_sensitive.size: # Depends on no other latents - fig = plot_update_1d(params, latent_i, observations, titles) - else: # It depends on latents other than itself. 
- fig = plot_update_2d( - params, - latent_i, - latent_sensitive[np.argmax(latent_sensitive)], - observations, - titles, - ) - if len(latent_sensitive) > 1: - print( - 'WARNING: This update rule depends on more than one ' - + 'other latent. Plotting just one of them' - ) - - figs.append(fig) + ) - return figs + latent_bottlenecks_open = np.sum(latent_sigmas < open_thresh) + choice_bottlenecks_open = np.sum(choice_sigmas < open_thresh) + update_bottlenecks_open = np.sum(update_sigmas < open_thresh) + + latent_bottlenecks_partial = np.sum(latent_sigmas < partially_open_thresh) + choice_bottlenecks_partial = np.sum(choice_sigmas < partially_open_thresh) + update_bottlenecks_partial = np.sum(update_sigmas < partially_open_thresh) + + latent_bottlenecks_closed = np.sum(latent_sigmas > closed_thresh) + choice_bottlenecks_closed = np.sum(choice_sigmas > closed_thresh) + update_bottlenecks_closed = np.sum(update_sigmas > closed_thresh) + + bottleneck_dict = { + 'latent_bottlenecks_open': int(latent_bottlenecks_open), + 'latent_bottlenecks_partial': int(latent_bottlenecks_partial), + 'latent_bottlenecks_closed': int(latent_bottlenecks_closed), + 'choice_bottlenecks_open': int(choice_bottlenecks_open), + 'choice_bottlenecks_partial': int(choice_bottlenecks_partial), + 'choice_bottlenecks_closed': int(choice_bottlenecks_closed), + 'update_bottlenecks_open': int(update_bottlenecks_open), + 'update_bottlenecks_partial': int(update_bottlenecks_partial), + 'update_bottlenecks_closed': int(update_bottlenecks_closed), + } + return bottleneck_dict + + +def get_total_sigma(params): + """Get sum of reparameterized sigmas of a DisRNN.""" + + params_disrnn = params['hk_disentangled_rnn'] + + latent_bottlenecks = reparameterize_sigma( + params_disrnn['latent_sigma_params']) + update_bottlenecks = reparameterize_sigma( + params_disrnn['update_net_sigma_params']) + choice_bottlenecks = reparameterize_sigma( + params_disrnn['choice_net_sigma_params']) + + return float( + 
jnp.sum(latent_bottlenecks) + + jnp.sum(update_bottlenecks) + + jnp.sum(choice_bottlenecks) + ) diff --git a/disentangled_rnns/library/disrnn_test.py b/disentangled_rnns/library/disrnn_test.py index 2d32605..cc26cbb 100644 --- a/disentangled_rnns/library/disrnn_test.py +++ b/disentangled_rnns/library/disrnn_test.py @@ -15,53 +15,108 @@ from absl.testing import absltest from disentangled_rnns.library import disrnn from disentangled_rnns.library import get_datasets +from disentangled_rnns.library import plotting from disentangled_rnns.library import rnn_utils -import optax - class DisrnnTest(absltest.TestCase): - def test_training_and_eval(self): - """Test that training and eval work.""" - dataset = get_datasets.get_q_learning_dataset(n_trials=10, n_sessions=10) + def setUp(self): + super().setUp() + self.disrnn_config = disrnn.DisRnnConfig( + latent_size=5, + obs_size=2, + output_size=2, + update_net_n_units_per_layer=4, + update_net_n_layers=2, + choice_net_n_units_per_layer=2, + choice_net_n_layers=2, + ) + self.q_dataset = get_datasets.get_q_learning_dataset( + n_sessions=11, n_trials=7 + ) + self.disrnn_params, _, _ = rnn_utils.train_network( + make_network=lambda: disrnn.HkDisentangledRNN(self.disrnn_config), + training_dataset=self.q_dataset, + validation_dataset=None, + n_steps=0, + ) + + def test_disrnn_params(self): + """Check that disRNN params are as expected.""" + disrnn_config = self.disrnn_config + disrnn_params = self.disrnn_params - # Train for a few steps - params, opt_state, losses = rnn_utils.train_network( - training_dataset=dataset, - validation_dataset=dataset, - make_network=disrnn.HkDisRNN, - opt=optax.adam(1e-3), - n_steps=20, - loss="penalized_categorical", - params=None, - opt_state=None, - loss_param=1e-3, - do_plot=True, + self.assertIn('hk_disentangled_rnn', disrnn_params) + self.assertIn( + 'hk_disentangled_rnn/~update_latents/update_net', + disrnn_params, + ) + self.assertIn( + 'hk_disentangled_rnn/~predict_targets/choice_net', 
disrnn_params ) - loss_init = losses["training_loss"][-1] - # Train for a few more steps - params, _, losses = rnn_utils.train_network( - training_dataset=dataset, - validation_dataset=dataset, - make_network=disrnn.HkDisRNN, - opt=optax.adam(1e-3), - n_steps=20, - loss="penalized_categorical", - params=params, - opt_state=opt_state, - loss_param=1e-3, - do_plot=True, + params = disrnn_params['hk_disentangled_rnn'] + update_net_params = disrnn_params[ + 'hk_disentangled_rnn/~update_latents/update_net' + ] + choice_net_params = disrnn_params[ + 'hk_disentangled_rnn/~predict_targets/choice_net' + ] + + self.assertIn('update_net_sigma_params', params) + self.assertIn('update_net_multipliers', params) + self.assertIn('latent_sigma_params', params) + self.assertIn('choice_net_sigma_params', params) + self.assertIn('choice_net_multipliers', params) + self.assertIn('latent_inits', params) + + # Check shapes based on config + latent_size = disrnn_config.latent_size + obs_size = disrnn_config.obs_size + net_input_size = latent_size + obs_size + + self.assertEqual( + params['update_net_sigma_params'].shape, (net_input_size, latent_size) + ) + self.assertEqual( + params['update_net_multipliers'].shape, (net_input_size, latent_size) + ) + self.assertEqual(params['latent_sigma_params'].shape, (latent_size,)) + self.assertEqual(params['choice_net_sigma_params'].shape, (latent_size,)) + self.assertEqual(params['choice_net_multipliers'].shape, (latent_size,)) + self.assertEqual(params['latent_inits'].shape, (latent_size,)) + self.assertEqual( + update_net_params['input_weights'].shape, + (net_input_size, disrnn_config.update_net_n_units_per_layer), + ) + self.assertEqual( + choice_net_params['input_weights'].shape, + (latent_size, disrnn_config.choice_net_n_units_per_layer), ) - loss_final = losses["training_loss"][-1] - # Check that loss has decreased - self.assertLess(loss_final, loss_init) + def test_disrnn_plotting(self): + plotting.plot_bottlenecks(self.disrnn_params, 
self.disrnn_config) + plotting.plot_update_rules(self.disrnn_params, self.disrnn_config) + plotting.plot_choice_rule(self.disrnn_params, self.disrnn_config) + + def test_disrnn_output_shape(self): + xs, _ = self.q_dataset.get_all() + n_sessions, n_trials = xs.shape[:2] + + network_outputs, network_states = rnn_utils.eval_network( + lambda: disrnn.HkDisentangledRNN(self.disrnn_config), + self.disrnn_params, + xs, + ) + # Output has shape (batch_size, output_size + 1). + # The first output_size elements are the predicted targets, andt the last + # element is the penalty + self.assertEqual(network_outputs.shape, (n_sessions, n_trials, 3)) + self.assertEqual(network_states.shape, (n_sessions, + n_trials, + self.disrnn_config.latent_size)) - # Check that plotting functions work - disrnn.plot_bottlenecks(params) - disrnn.plot_update_rules(params, disrnn.HkDisRNN) -if __name__ == "__main__": +if __name__ == '__main__': absltest.main() diff --git a/disentangled_rnns/library/example_test.py b/disentangled_rnns/library/example_test.py index 7f32b89..aedce8f 100644 --- a/disentangled_rnns/library/example_test.py +++ b/disentangled_rnns/library/example_test.py @@ -27,8 +27,8 @@ def test_example_script(self): try: FLAGS.n_steps_per_session = 10 FLAGS.n_sessions = 10 - FLAGS.n_training_steps = 20 - FLAGS.n_warmup_steps = 20 + FLAGS.n_training_steps = 10 + FLAGS.n_warmup_steps = 10 example.main(None) except Exception as e: # pylint: disable=broad-exception-caught self.fail(f"Example script failed: {e}") diff --git a/disentangled_rnns/library/plotting.py b/disentangled_rnns/library/plotting.py new file mode 100644 index 0000000..6aa8975 --- /dev/null +++ b/disentangled_rnns/library/plotting.py @@ -0,0 +1,405 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plotting functions for inspecting Disentangled RNNs.""" + +import copy +from typing import Optional + +from disentangled_rnns.library import disrnn +from disentangled_rnns.library import rnn_utils +import haiku as hk +import jax +import jax.numpy as jnp +import matplotlib as mpl +from matplotlib import pyplot as plt +import numpy as np + + +def plot_bottlenecks( + params: hk.Params, + disrnn_config: disrnn.DisRnnConfig, + sort_latents: bool = True, +) -> plt.Figure: + """Plot the bottleneck sigmas from an hk.DisentangledRNN.""" + + params_disrnn = params['hk_disentangled_rnn'] + + latent_dim = params_disrnn['latent_sigma_params'].shape[0] + obs_dim = params_disrnn['update_net_sigma_params'].shape[0] - latent_dim + + update_input_names = disrnn_config.x_names + + latent_sigmas = np.array( + disrnn.reparameterize_sigma(params_disrnn['latent_sigma_params']) + ) + update_sigmas = np.array( + disrnn.reparameterize_sigma( + np.transpose(params_disrnn['update_net_sigma_params']) + ) + ) + choice_sigmas = np.array( + disrnn.reparameterize_sigma( + np.transpose(params_disrnn['choice_net_sigma_params']) + ) + ) + + if sort_latents: + latent_sigma_order = np.argsort(latent_sigmas) + latent_sigmas = latent_sigmas[latent_sigma_order] + choice_sigmas = choice_sigmas[latent_sigma_order] + + update_sigma_order = np.concatenate( + (np.arange(0, obs_dim, 1), obs_dim + latent_sigma_order), axis=0 + ) + update_sigmas = update_sigmas[latent_sigma_order, :] + update_sigmas = update_sigmas[:, update_sigma_order] + + latent_names = np.arange(1, latent_dim + 1) + fig, axes = 
plt.subplots(1, 3, figsize=(15, 5)) + + # Plot Latent Bottlenecks on axes[0] + im1 = axes[0].imshow(np.swapaxes([1 - latent_sigmas], 0, 1), cmap='Oranges') + im1.set_clim(vmin=0, vmax=1) + axes[0].set_yticks(ticks=range(latent_dim), labels=latent_names) + axes[0].set_xticks(ticks=[]) + axes[0].set_ylabel('Latent #') + axes[0].set_title('Latent Bottlenecks') + + # Plot Choice Bottlenecks on axes[1] + im2 = axes[1].imshow(np.swapaxes([1 - choice_sigmas], 0, 1), cmap='Oranges') + im2.set_clim(vmin=0, vmax=1) + axes[1].set_yticks(ticks=range(latent_dim), labels=latent_names) + axes[1].set_xticks(ticks=[]) + axes[1].set_ylabel('Latent #') + axes[1].set_title('Choice Network Bottlenecks') + + # Plot Update Bottlenecks on axes[2] + im3 = axes[2].imshow(1 - update_sigmas, cmap='Oranges') + im3.set_clim(vmin=0, vmax=1) + fig.colorbar(im3, ax=axes[2]) + axes[2].set_yticks(ticks=range(latent_dim), labels=latent_names) + xlabels = np.concatenate((np.array(update_input_names), latent_names)) + axes[2].set_xticks( + ticks=range(len(xlabels)), + labels=xlabels, + rotation='vertical', + ) + axes[2].set_ylabel('Latent #') + axes[2].set_title('Update Network Bottlenecks') + fig.tight_layout() # Adjust layout to prevent overlap + return fig + + +def plot_update_rules( + params: hk.Params, + disrnn_config: disrnn.DisRnnConfig, + subj_ind: Optional[int] = None, + axis_lim: float = 2.1, +) -> plt.Figure: + """Generates visualizations of the update rules of a HkDisentangledRNN.""" + + disrnn_config = copy.deepcopy(disrnn_config) + disrnn_config.noiseless_mode = True # Turn off noise for plotting + + + make_network = lambda: disrnn.HkDisentangledRNN(disrnn_config) + obs_names = disrnn_config.x_names + param_prefix = 'hk_disentangled_rnn' + + def step(xs, state): + core = make_network() + output, new_state = core(jnp.expand_dims(jnp.array(xs), axis=0), state) + return output, new_state + + _, step_hk = hk.transform(step) + key = jax.random.PRNGKey(0) + step_hk = jax.jit(step_hk) + + 
initial_state = np.array(rnn_utils.get_initial_state(make_network)) + reference_state = np.zeros(initial_state.shape) + + def plot_update_1d(params, unit_i, observations, titles): + state_bins = np.linspace(-axis_lim, axis_lim, 20) + colormap = mpl.colormaps['viridis'].resampled(3) + colors = colormap.colors + + fig, axes = plt.subplots( + 1, len(observations), figsize=(len(observations) * 4, 5.5), sharey=True + ) + # Ensure axes is always an array for consistent indexing + if len(observations) == 1: + axes = [axes] + axes[0].set_ylabel('Δ Activity') + + for observation_i in range(len(observations)): + observation = observations[observation_i] + if subj_ind is not None: + observation = [subj_ind] + observation + ax = axes[observation_i] + delta_states = np.zeros(shape=(len(state_bins), 1)) + for s_i in np.arange(len(state_bins)): + state = reference_state + state[0, unit_i] = state_bins[s_i] + _, next_state = step_hk( + params, key, observation, state + ) + next_state = np.array(next_state) + delta_states[s_i] = next_state[0, unit_i] - state_bins[s_i] + + ax.plot((-axis_lim, axis_lim), (0, 0), color='black') + ax.plot(state_bins, delta_states, color=colors[1]) + ax.set_title(titles[observation_i]) + ax.set_xlim(-axis_lim, axis_lim) + ax.set_xlabel('Latent ' + str(unit_i + 1) + ' Activity') + ax.set_aspect('equal') + + return fig + + def plot_update_2d(params, unit_i, unit_input, observations, titles): + + state_bins = np.linspace(-axis_lim, axis_lim, 50) + state_bins_input = np.linspace(-axis_lim/2, axis_lim/2, 5) + colormap = mpl.colormaps['viridis'].resampled(len(state_bins_input)) + colors = colormap.colors + + fig, axes = plt.subplots( + 1, + len(observations), + figsize=(len(observations) * 2 + 10, 5.5), + sharey=True, + ) + # Ensure axes is always an array for consistent indexing + if len(observations) == 1: + axes = [axes] + axes[0].set_ylabel('Δ Activity') + + for observation_i in range(len(observations)): + observation = observations[observation_i] + if 
subj_ind is not None: + observation = [subj_ind] + observation + legend_elements = [] + ax = axes[observation_i] + for si_i in np.arange(len(state_bins_input)): + delta_states = np.zeros(shape=(len(state_bins), 1)) + for s_i in np.arange(len(state_bins)): + state = reference_state + state[0, unit_i] = state_bins[s_i] + state[0, unit_input] = state_bins_input[si_i] + + _, next_state = step_hk(params, key, observation, state) + next_state = np.array(next_state) + delta_states[s_i] = next_state[0, unit_i] - state_bins[s_i] + + lines = ax.plot(state_bins, delta_states, color=colors[si_i]) + legend_elements.append(lines[0]) + + if observation_i == 0: + legend_labels = [f'{num:.1f}' for num in state_bins_input] # pylint: disable=bad-whitespace + ax.legend(legend_elements, legend_labels) + + ax.plot((-axis_lim, axis_lim), (0, 0), color='black') + ax.set_title(titles[observation_i]) + ax.set_xlim(-axis_lim, axis_lim) + ax.set_xlabel('Latent ' + str(unit_i + 1) + ' Activity') + + return fig + + latent_sigmas = np.array( + disrnn.reparameterize_sigma( + params[param_prefix]['latent_sigma_params'] + ) + ) + update_sigmas = np.array( + disrnn.reparameterize_sigma( + np.transpose(params[param_prefix]['update_net_sigma_params']) + ) + ) + + obs_size = 2 + + latent_order = np.argsort(latent_sigmas) + figs = [] + + # Loop over latents. Plot update rules + for latent_i in latent_order: + # If this latent's bottleneck is open + if latent_sigmas[latent_i] < 0.5: + + # Which of its input bottlenecks are open? 
+ update_net_inputs = np.argwhere(update_sigmas[latent_i] < 0.5) + # TODO(kevinjmiller): Generalize to allow different observation length + obs1_sensitive = np.any(update_net_inputs == 0) + obs2_sensitive = np.any(update_net_inputs == 1) + # Choose which observations to use based on input bottlenecks + if obs1_sensitive and obs2_sensitive: + observations = ([0, 0], [0, 1], [1, 0], [1, 1]) + titles = ( + obs_names[0] + ': 0\n' + obs_names[1] + ': 0', + obs_names[0] + ': 0\n' + obs_names[1] + ': 1', + obs_names[0] + ': 1\n' + obs_names[1] + ': 0', + obs_names[0] + ': 1\n' + obs_names[1] + ': 1', + ) + elif obs1_sensitive: + observations = ([0, 0], [1, 0]) + titles = (obs_names[0] + ': 0', obs_names[0] + ': 1') + elif obs2_sensitive: + observations = ([0, 0], [0, 1]) + titles = (obs_names[1] + ': 0', obs_names[1] + ': 1') + else: + observations = ([0, 0],) + titles = ('All Trials',) + + # Choose whether to condition on other latent values + update_net_input_latents = ( + update_net_inputs[obs_size:, 0] + - (obs_size) + ) + # Doesn't count if it depends on itself (this'll be shown no matter what) + latent_sensitive = np.delete( + update_net_input_latents, update_net_input_latents == latent_i + ) + if not latent_sensitive.size: # Depends on no other latents + fig = plot_update_1d(params, latent_i, observations, titles) + else: # It depends on latents other than itself. + fig = plot_update_2d( + params, + latent_i, + latent_sensitive[0], + observations, + titles, + ) + if len(latent_sensitive) > 1: + print( + 'WARNING: This update rule depends on more than one ' + 'other latent. Plotting just one of them' + ) + figs.append(fig) + fig.tight_layout() + + return figs + + +def plot_choice_rule( + params: hk.Params, + disrnn_config: disrnn.DisRnnConfig, + axis_lim: float = 2.1, +) -> Optional[plt.Figure]: + """Plots the choice rule of a DisRNN. + + Args: + params: The parameters of the DisRNN + disrnn_config: A DisRnnConfig object + axis_lim: The axis limit for the plot. 
+ + Returns: + A matplotlib Figure object, or None if choice depends on no latents. + """ + + disrnn_config = copy.deepcopy(disrnn_config) + disrnn_config.noiseless_mode = True # Turn off noise for plotting + activation_fn = getattr(jax.nn, disrnn_config.activation) + + params_prefix = 'hk_disentangled_rnn' + + n_vals = 100 + + def forward(xs): + choice_net_output = disrnn.ResMLP( + input_size=disrnn_config.latent_size, + output_size=disrnn_config.output_size, + n_units_per_layer=disrnn_config.choice_net_n_units_per_layer, + n_layers=disrnn_config.choice_net_n_layers, + activation_fn=activation_fn, + name='choice_net', + )(xs) + return choice_net_output + + model = hk.transform(forward) + apply = jax.jit(model.apply) + + choice_net_params = { + 'choice_net': params[params_prefix + '/~predict_targets/choice_net'] + } + choice_net_sigmas = disrnn.reparameterize_sigma( + params[params_prefix]['choice_net_sigma_params'] + ) + n_inputs = np.sum(choice_net_sigmas < 0.5) + choice_net_input_order = np.argsort(choice_net_sigmas) + + if n_inputs == 0: + print('Choice does not depend on any latents') + return None + elif n_inputs == 1: + # Choice Rule 1D: A curve + policy_latent_ind = choice_net_input_order[0] + policy_latent_vals = np.linspace(-axis_lim, axis_lim, n_vals) + xs = np.zeros(( + n_vals, + disrnn_config.latent_size, + )) + xs[:, policy_latent_ind] = policy_latent_vals + y_hats, _ = apply(choice_net_params, jax.random.PRNGKey(0), xs) + choice_logits = y_hats[:, 0] - y_hats[:, 1] + + fig, ax = plt.subplots() + ax.plot(policy_latent_vals, choice_logits, 'g') + ax.set_title('Choice Rule') + ax.set_xlabel(f'Latent {policy_latent_ind + 1}') + ax.set_ylabel('Choice Logit') + + else: + # Choice Rule 2D: A colormap + if n_inputs > 2: + print( + 'WARNING: More than two latents contribute to choice. Plotting only', + ' the first two.' 
+ ) + + policy_latent_inds = choice_net_input_order[:2] + + latent_vals = np.linspace(-axis_lim, axis_lim, n_vals) + + xv, yv = np.meshgrid(latent_vals, latent_vals) + latent0_vals = np.reshape(xv, (xv.size,)) + latent1_vals = np.reshape(yv, (yv.size,)) + + xs = np.zeros( + shape=( + n_vals**2, + disrnn_config.latent_size, + ) + ) + xs[:, policy_latent_inds[0]] = latent0_vals + xs[:, policy_latent_inds[1]] = latent1_vals + + y_hats, _ = apply(choice_net_params, jax.random.PRNGKey(0), xs) + # TODO(kevinjmiller): This assumes two-alternative logits. Generalize to + # allow more alternatives and/or scalar outputs + choice_logits = y_hats[:, 1] - y_hats[:, 0] + + cmax = np.max(np.abs(choice_logits)) + + fig, ax = plt.subplots() + scatter = ax.scatter( + latent0_vals, latent1_vals, c=choice_logits, s=100, cmap='bwr' + ) + scatter.set_clim(-cmax, cmax) + cbar = fig.colorbar(scatter, ax=ax) + ax.set_title('Choice Rule') + ax.set_xlabel(f'Latent {policy_latent_inds[0]+1}') + ax.set_ylabel(f'Latent {policy_latent_inds[1]+1}') + cbar.set_label('Choice Logit') + + return fig diff --git a/disentangled_rnns/notebooks/train_single_gru.ipynb b/disentangled_rnns/notebooks/train_single_gru.ipynb new file mode 100644 index 0000000..122ecc8 --- /dev/null +++ b/disentangled_rnns/notebooks/train_single_gru.ipynb @@ -0,0 +1,205 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZST5GqoRtfZz" + }, + "outputs": [], + "source": [ + "import optax\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import haiku as hk\n", + "\n", + "from disentangled_rnns.library import rnn_utils\n", + "from disentangled_rnns.library import get_datasets\n", + "\n", + "# Setup so that plots will look nice\n", + "small = 15\n", + "medium = 18\n", + "large = 20\n", + "plt.rc('axes', titlesize=large)\n", + "plt.rc('axes', labelsize=medium)\n", + "plt.rc('xtick', labelsize=small)\n", + "plt.rc('ytick', labelsize=small)\n", + "plt.rc('legend', 
fontsize=small)\n", + "plt.rc('figure', titlesize=large)\n", + "mpl.rcParams['grid.color'] = 'none'\n", + "mpl.rcParams['axes.facecolor'] = 'white'\n", + "plt.rcParams['svg.fonttype'] = 'none'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ApA1YfVGz9Uq" + }, + "source": [ + "# Define a dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MIhLKbgHPYmQ" + }, + "outputs": [], + "source": [ + "dataset = get_datasets.get_q_learning_dataset(n_sessions=500, n_trials=200)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "caSlZS4OR0PK" + }, + "outputs": [], + "source": [ + "dataset_train, dataset_eval = rnn_utils.split_dataset(dataset, eval_every_n=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ONzEfURn0DU4" + }, + "source": [ + "# Define and train RNN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "h5lNm21PRJti" + }, + "outputs": [], + "source": [ + "# Define the architecture of the network we'd like to train\n", + "n_hidden = 16\n", + "output_size = 2\n", + "\n", + "def make_network():\n", + " model = hk.DeepRNN(\n", + " [hk.GRU(n_hidden), hk.Linear(output_size=output_size)]\n", + " )\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OULn6VOf0l-R" + }, + "outputs": [], + "source": [ + "# INITIALIZE THE NETWORK\n", + "# Running rnn_utils.train_network with n_steps=0 does no training but sets up the\n", + "# parameters and optimizer state.\n", + "params, opt_state, losses = rnn_utils.train_network(\n", + " make_network = make_network,\n", + " training_dataset=dataset_train,\n", + " validation_dataset=dataset_eval,\n", + " opt = optax.adam(1e-2),\n", + " loss=\"categorical\",\n", + " n_steps=0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JwFLIG_U1Eli" + }, + "outputs": [], + "source": [ + "# TRAIN THE 
NETWORK\n", + "# Running this cell repeatedly continues to train the same network.\n", + "# The cell below gives insight into what's going on in your network.\n", + "# If you'd like to reinitialize the network and start over, re-run the above cell\n", + "\n", + "n_steps = 1000\n", + "\n", + "params, opt_state, losses = rnn_utils.train_network(\n", + " make_network = make_network,\n", + " training_dataset=dataset_train,\n", + " validation_dataset=dataset_eval,\n", + " loss=\"categorical\",\n", + " params=params,\n", + " opt_state=opt_state,\n", + " opt = optax.adam(1e-3),\n", + " loss_param = 1,\n", + " n_steps=n_steps,\n", + " do_plot = True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oiPRjxjQSFLH" + }, + "outputs": [], + "source": [ + "# Run forward pass on the unseen data\n", + "xs_eval, ys_eval = dataset_eval.get_all()\n", + "network_output, network_states = rnn_utils.eval_network(make_network, params, xs_eval)\n", + "\n", + "# Compute normalized likelihood\n", + "score = rnn_utils.normalized_likelihood(ys_eval, network_output)\n", + "print(f'Normalized Likelihood: {100*score:.1f}%')\n", + "\n", + "# Plot network activations on an example session\n", + "example_session = 0\n", + "plt.plot(network_states[:,example_session,:])\n", + "plt.xlabel('Trial Number')\n", + "plt.ylabel('Network Activations')" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//learning/deepmind/dm_python:dm_notebook3", + "kind": "private" + }, + "private_outputs": true, + "provenance": [ + { + "file_id": "1tbH1PMKB0rz4ajkRQlzBV7iYCll5RlhS", + "timestamp": 1746631254097 + }, + { + "file_id": "/piper/depot/google3/learning/deepmind/research/neuroexp/disrnn/notebooks/train_single_disrnn.ipynb?workspaceId=kevinjmiller:disentangled_rnns::citc", + "timestamp": 1746630089612 + }, + { + "file_id": "1b5VOqHaVDOJ3fAW2E853NBQbSu2Yi-CP", + "timestamp": 1727798409618 + }, + { + "file_id": 
"1xgFbsQ34Of-WBTEQM_Hf7Di7N9YpRmdR", + "timestamp": 1726760254895 + }, + { + "file_id": "1IuwwEfCic7w3NsyVoVPtZSQCzrvTgh_X", + "timestamp": 1696507812638 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}