We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a19598f commit 42eb865Copy full SHA for 42eb865
disentangled_rnns/library/disrnn_test.py
@@ -117,6 +117,17 @@ def test_disrnn_output_shape(self):
117
n_trials,
118
self.disrnn_config.latent_size))
119
120
+ def test_disrnn_trainable(self):
121
+ """Smoke test to check that the disRNN can be trained."""
122
+ n_steps = 10
123
+ _, _, _ = rnn_utils.train_network(
124
+ make_network=lambda: disrnn.HkDisentangledRNN(self.disrnn_config),
125
+ training_dataset=self.q_dataset,
126
+ validation_dataset=None,
127
+ params=self.disrnn_params,
128
+ n_steps=n_steps,
129
+ )
130
+
131
132
if __name__ == '__main__':
133
absltest.main()
0 commit comments