Skip to content

Add example colab for training GRU #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 25 additions & 30 deletions disentangled_rnns/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
"""Example DisRNN workflow: Define a dataset, train network, inspect the fit.
"""

import copy

from absl import app
from absl import flags
from disentangled_rnns.library import disrnn
from disentangled_rnns.library import plotting
from disentangled_rnns.library import rnn_utils
from disentangled_rnns.library import two_armed_bandits
import optax
Expand Down Expand Up @@ -64,18 +67,18 @@ def main(_) -> None:
)

# Define the disRNN architecture
update_mlp_shape = (3, 5, 5)
choice_mlp_shape = (2, 2)
latent_size = 5

def make_network():
return disrnn.HkDisRNN(
update_mlp_shape=update_mlp_shape,
choice_mlp_shape=choice_mlp_shape,
latent_size=latent_size,
obs_size=2,
target_size=2,
)
disrnn_config = disrnn.DisRnnConfig(
obs_size=2,
output_size=2,
latent_size=5,
update_net_n_units_per_layer=8,
update_net_n_layers=4,
choice_net_n_units_per_layer=4,
choice_net_n_layers=2,
x_names=dataset.x_names,
y_names=dataset.y_names,
)
make_disrnn = lambda: disrnn.HkDisentangledRNN(disrnn_config)

opt = optax.adam(learning_rate=FLAGS.learning_rate)

Expand All @@ -85,7 +88,7 @@ def make_network():

# Warmup training with no information penalty
params, _, _ = rnn_utils.train_network(
make_network,
make_disrnn,
training_dataset=dataset,
validation_dataset=dataset_eval,
loss="penalized_categorical",
Expand All @@ -99,7 +102,7 @@ def make_network():

# Additional training using information penalty
params, _, _ = rnn_utils.train_network(
make_network,
make_disrnn,
training_dataset=dataset,
validation_dataset=dataset_eval,
loss="penalized_categorical",
Expand All @@ -114,28 +117,20 @@ def make_network():
###########################
# Inspecting a fit disRNN #
###########################

# Eval mode runs the network with no noise
def make_network_eval():
return disrnn.HkDisRNN(
update_mlp_shape=update_mlp_shape,
choice_mlp_shape=choice_mlp_shape,
latent_size=latent_size,
obs_size=2,
target_size=2,
eval_mode=True,
)

disrnn.plot_bottlenecks(params, make_network_eval)
disrnn.plot_update_rules(params, make_network_eval)
# Plot bottleneck structure and update rules
plotting.plot_bottlenecks(params, disrnn_config)
plotting.plot_update_rules(params, disrnn_config)

##############################
# Eval disRNN on unseen data #
##############################

# Run the network in noiseless mode to see evolution of states over time
config_noiseless = copy.deepcopy(disrnn_config)
config_noiseless.noiseless_mode = True
make_noiseless_disrnn = lambda: disrnn.HkDisentangledRNN(config_noiseless)
xs, _ = next(dataset_eval)
# pylint: disable-next=unused-variable
_, network_states = rnn_utils.eval_network(make_network_eval, params, xs)
_, network_states = rnn_utils.eval_network(make_noiseless_disrnn, params, xs)


if __name__ == "__main__":
Expand Down
Loading
Loading