Skip to content

Commit ac918cf

Browse files
kevin-j-millercopybara-github
authored andcommitted
Add a smoke test for disRNN training
PiperOrigin-RevId: 756730696
1 parent a19598f commit ac918cf

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

disentangled_rnns/library/disrnn_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,17 @@ def test_disrnn_output_shape(self):
117117
n_trials,
118118
self.disrnn_config.latent_size))
119119

120+
def test_disrnn_trainable(self):
121+
"""Smoke test to check that the disRNN can be trained."""
122+
n_steps = 10
123+
_, _, _ = rnn_utils.train_network(
124+
make_network=lambda: disrnn.HkDisentangledRNN(self.disrnn_config),
125+
training_dataset=self.q_dataset,
126+
validation_dataset=None,
127+
params=self.disrnn_params,
128+
n_steps=n_steps,
129+
)
130+
120131

121132
if __name__ == '__main__':
122133
absltest.main()

0 commit comments

Comments
 (0)