15
15
"""Example DisRNN workflow: Define a dataset, train network, inspect the fit.
16
16
"""
17
17
18
+ import copy
19
+
18
20
from absl import app
19
21
from absl import flags
20
22
from disentangled_rnns .library import disrnn
23
+ from disentangled_rnns .library import plotting
21
24
from disentangled_rnns .library import rnn_utils
22
25
from disentangled_rnns .library import two_armed_bandits
23
26
import optax
@@ -64,18 +67,18 @@ def main(_) -> None:
64
67
)
65
68
66
69
# Define the disRNN architecture
67
- update_mlp_shape = ( 3 , 5 , 5 )
68
- choice_mlp_shape = ( 2 , 2 )
69
- latent_size = 5
70
-
71
- def make_network ():
72
- return disrnn . HkDisRNN (
73
- update_mlp_shape = update_mlp_shape ,
74
- choice_mlp_shape = choice_mlp_shape ,
75
- latent_size = latent_size ,
76
- obs_size = 2 ,
77
- target_size = 2 ,
78
- )
70
+ disrnn_config = disrnn . DisRnnConfig (
71
+ obs_size = 2 ,
72
+ output_size = 2 ,
73
+ latent_size = 5 ,
74
+ update_net_n_units_per_layer = 8 ,
75
+ update_net_n_layers = 4 ,
76
+ choice_net_n_units_per_layer = 4 ,
77
+ choice_net_n_layers = 2 ,
78
+ x_names = dataset . x_names ,
79
+ y_names = dataset . y_names ,
80
+ )
81
+ make_disrnn = lambda : disrnn . HkDisentangledRNN ( disrnn_config )
79
82
80
83
opt = optax .adam (learning_rate = FLAGS .learning_rate )
81
84
@@ -85,7 +88,7 @@ def make_network():
85
88
86
89
# Warmup training with no information penalty
87
90
params , _ , _ = rnn_utils .train_network (
88
- make_network ,
91
+ make_disrnn ,
89
92
training_dataset = dataset ,
90
93
validation_dataset = dataset_eval ,
91
94
loss = "penalized_categorical" ,
@@ -99,7 +102,7 @@ def make_network():
99
102
100
103
# Additional training using information penalty
101
104
params , _ , _ = rnn_utils .train_network (
102
- make_network ,
105
+ make_disrnn ,
103
106
training_dataset = dataset ,
104
107
validation_dataset = dataset_eval ,
105
108
loss = "penalized_categorical" ,
@@ -114,28 +117,20 @@ def make_network():
114
117
###########################
115
118
# Inspecting a fit disRNN #
116
119
###########################
117
-
118
- # Eval mode runs the network with no noise
119
- def make_network_eval ():
120
- return disrnn .HkDisRNN (
121
- update_mlp_shape = update_mlp_shape ,
122
- choice_mlp_shape = choice_mlp_shape ,
123
- latent_size = latent_size ,
124
- obs_size = 2 ,
125
- target_size = 2 ,
126
- eval_mode = True ,
127
- )
128
-
129
- disrnn .plot_bottlenecks (params , make_network_eval )
130
- disrnn .plot_update_rules (params , make_network_eval )
120
+ # Plot bottleneck structure and update rules
121
+ plotting .plot_bottlenecks (params , disrnn_config )
122
+ plotting .plot_update_rules (params , disrnn_config )
131
123
132
124
##############################
133
125
# Eval disRNN on unseen data #
134
126
##############################
135
-
127
+ # Run the network in noiseless mode to see evolution of states over time
128
+ config_noiseless = copy .deepcopy (disrnn_config )
129
+ config_noiseless .noiseless_mode = True
130
+ make_noiseless_disrnn = lambda : disrnn .HkDisentangledRNN (config_noiseless )
136
131
xs , _ = next (dataset_eval )
137
132
# pylint: disable-next=unused-variable
138
- _ , network_states = rnn_utils .eval_network (make_network_eval , params , xs )
133
+ _ , network_states = rnn_utils .eval_network (make_noiseless_disrnn , params , xs )
139
134
140
135
141
136
if __name__ == "__main__" :
0 commit comments