diff --git a/disentangled_rnns/library/disrnn_test.py b/disentangled_rnns/library/disrnn_test.py index cc26cbb..ec7d996 100644 --- a/disentangled_rnns/library/disrnn_test.py +++ b/disentangled_rnns/library/disrnn_test.py @@ -117,6 +117,17 @@ def test_disrnn_output_shape(self): n_trials, self.disrnn_config.latent_size)) + def test_disrnn_trainable(self): + """Smoke test to check that the disRNN can be trained.""" + n_steps = 10 + _, _, _ = rnn_utils.train_network( + make_network=lambda: disrnn.HkDisentangledRNN(self.disrnn_config), + training_dataset=self.q_dataset, + validation_dataset=None, + params=self.disrnn_params, + n_steps=n_steps, + ) + if __name__ == '__main__': absltest.main()