From 26fb07e87ae024f97780ce7c7ca924d2f039b6a7 Mon Sep 17 00:00:00 2001 From: Kevin Miller Date: Sat, 10 May 2025 02:25:35 -0700 Subject: [PATCH] Add a smoke test for disRNN training PiperOrigin-RevId: 757090159 --- disentangled_rnns/library/disrnn_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/disentangled_rnns/library/disrnn_test.py b/disentangled_rnns/library/disrnn_test.py index cc26cbb..ec7d996 100644 --- a/disentangled_rnns/library/disrnn_test.py +++ b/disentangled_rnns/library/disrnn_test.py @@ -117,6 +117,17 @@ def test_disrnn_output_shape(self): n_trials, self.disrnn_config.latent_size)) + def test_disrnn_trainable(self): + """Smoke test to check that the disRNN can be trained.""" + n_steps = 10 + _, _, _ = rnn_utils.train_network( + make_network=lambda: disrnn.HkDisentangledRNN(self.disrnn_config), + training_dataset=self.q_dataset, + validation_dataset=None, + params=self.disrnn_params, + n_steps=n_steps, + ) + if __name__ == '__main__': absltest.main()