Skip to content

Commit 6116409

Browse files
kevin-j-miller and copybara-github
kevin-j-miller authored and copybara-github committed
Add example colab for training GRU
PiperOrigin-RevId: 756288921
1 parent 594ca0e commit 6116409

File tree

6 files changed

+1213
-405
lines changed

6 files changed

+1213
-405
lines changed

disentangled_rnns/example.py

+25-30
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,12 @@
1515
"""Example DisRNN workflow: Define a dataset, train network, inspect the fit.
1616
"""
1717

18+
import copy
19+
1820
from absl import app
1921
from absl import flags
2022
from disentangled_rnns.library import disrnn
23+
from disentangled_rnns.library import plotting
2124
from disentangled_rnns.library import rnn_utils
2225
from disentangled_rnns.library import two_armed_bandits
2326
import optax
@@ -64,18 +67,18 @@ def main(_) -> None:
6467
)
6568

6669
# Define the disRNN architecture
67-
update_mlp_shape = (3, 5, 5)
68-
choice_mlp_shape = (2, 2)
69-
latent_size = 5
70-
71-
def make_network():
72-
return disrnn.HkDisRNN(
73-
update_mlp_shape=update_mlp_shape,
74-
choice_mlp_shape=choice_mlp_shape,
75-
latent_size=latent_size,
76-
obs_size=2,
77-
target_size=2,
78-
)
70+
disrnn_config = disrnn.DisRnnConfig(
71+
obs_size=2,
72+
output_size=2,
73+
latent_size=5,
74+
update_net_n_units_per_layer=8,
75+
update_net_n_layers=4,
76+
choice_net_n_units_per_layer=4,
77+
choice_net_n_layers=2,
78+
x_names=dataset.x_names,
79+
y_names=dataset.y_names,
80+
)
81+
make_disrnn = lambda: disrnn.HkDisentangledRNN(disrnn_config)
7982

8083
opt = optax.adam(learning_rate=FLAGS.learning_rate)
8184

@@ -85,7 +88,7 @@ def make_network():
8588

8689
# Warmup training with no information penalty
8790
params, _, _ = rnn_utils.train_network(
88-
make_network,
91+
make_disrnn,
8992
training_dataset=dataset,
9093
validation_dataset=dataset_eval,
9194
loss="penalized_categorical",
@@ -99,7 +102,7 @@ def make_network():
99102

100103
# Additional training using information penalty
101104
params, _, _ = rnn_utils.train_network(
102-
make_network,
105+
make_disrnn,
103106
training_dataset=dataset,
104107
validation_dataset=dataset_eval,
105108
loss="penalized_categorical",
@@ -114,28 +117,20 @@ def make_network():
114117
###########################
115118
# Inspecting a fit disRNN #
116119
###########################
117-
118-
# Eval mode runs the network with no noise
119-
def make_network_eval():
120-
return disrnn.HkDisRNN(
121-
update_mlp_shape=update_mlp_shape,
122-
choice_mlp_shape=choice_mlp_shape,
123-
latent_size=latent_size,
124-
obs_size=2,
125-
target_size=2,
126-
eval_mode=True,
127-
)
128-
129-
disrnn.plot_bottlenecks(params, make_network_eval)
130-
disrnn.plot_update_rules(params, make_network_eval)
120+
# Plot bottleneck structure and update rules
121+
plotting.plot_bottlenecks(params, disrnn_config)
122+
plotting.plot_update_rules(params, disrnn_config)
131123

132124
##############################
133125
# Eval disRNN on unseen data #
134126
##############################
135-
127+
# Run the network in noiseless mode to see evolution of states over time
128+
config_noiseless = copy.deepcopy(disrnn_config)
129+
config_noiseless.noiseless_mode = True
130+
make_noiseless_disrnn = lambda: disrnn.HkDisentangledRNN(config_noiseless)
136131
xs, _ = next(dataset_eval)
137132
# pylint: disable-next=unused-variable
138-
_, network_states = rnn_utils.eval_network(make_network_eval, params, xs)
133+
_, network_states = rnn_utils.eval_network(make_noiseless_disrnn, params, xs)
139134

140135

141136
if __name__ == "__main__":

0 commit comments

Comments
 (0)