diff --git a/heat/decomposition/tests/test_dmd.py b/heat/decomposition/tests/test_dmd.py index 9f4fa14e9..28da35856 100644 --- a/heat/decomposition/tests/test_dmd.py +++ b/heat/decomposition/tests/test_dmd.py @@ -217,3 +217,127 @@ def test_dmd_correctness(self): # check batch prediction (split = None) X_batch = ht.random.rand(10, 2 * ht.MPI_WORLD.size, split=None) Y = dmd.predict(X_batch, [-1, 1, 3]) + + +class TestDMDc(TestCase): + def test_dmdc_setup_and_catch_wrong(self): + # catch wrong inputs + with self.assertRaises(TypeError): + ht.decomposition.DMDc(svd_solver=0) + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="Gramian") + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="full", svd_rank=3, svd_tol=1e-1) + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="full", svd_tol=-0.031415926) + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="hierarchical") + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=3, svd_tol=1e-1) + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="randomized") + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="randomized", svd_rank=2, svd_tol=1e-1) + with self.assertRaises(TypeError): + ht.decomposition.DMDc(svd_solver="full", svd_rank=0.1) + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=0) + with self.assertRaises(TypeError): + ht.decomposition.DMDc(svd_solver="hierarchical", svd_tol="auto") + with self.assertRaises(ValueError): + ht.decomposition.DMDc(svd_solver="randomized", svd_rank=0) + + dmd = ht.decomposition.DMDc(svd_solver="full") + # wrong dimensions of input + with self.assertRaises(ValueError): + dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 2, 2), split=0), ht.zeros((2, 4), split=0)) + with self.assertRaises(ValueError): + dmd.fit(ht.zeros((2, 4), split=0), ht.zeros((5 * ht.MPI_WORLD.size, 2, 2), split=0)) + # less than two timesteps + with self.assertRaises(ValueError): + dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 1), split=0), ht.zeros((2, 4), split=0)) + with self.assertRaises(ValueError): + dmd.fit(ht.zeros((2, 4), split=0), ht.zeros((5 * ht.MPI_WORLD.size, 1), split=0)) + # inconsistent number of timesteps + with self.assertRaises(ValueError): + dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 3), split=0), ht.zeros((2, 4), split=0)) + + # def test_dmdc_functionality_split0(self): + # # check whether the everything works with split=0, various checks are scattered over the different cases + # X = ht.random.randn(10 * ht.MPI_WORLD.size, 10, split=0) + # C = ht.random.randn(10, 10, split=0) + # dmd = ht.decomposition.DMDc(svd_solver="full") + # dmd.fit(X,C) + # self.assertTrue(dmd.rom_eigenmodes_.dtype == ht.complex64) + # self.assertEqual(dmd.rom_eigenmodes_.shape, (dmd.n_modes_, dmd.n_modes_)) + # dmd = ht.decomposition.DMD(svd_solver="full", svd_tol=1e-1) + # dmd.fit(X,C) + # self.assertTrue(dmd.rom_basis_.shape[0] == 10 * ht.MPI_WORLD.size) + # dmd = ht.decomposition.DMD(svd_solver="full", svd_rank=3) + # dmd.fit(X,C) + # self.assertTrue(dmd.rom_basis_.shape[1] == 3) + # self.assertTrue(dmd.dmdmodes_.shape == (10 * ht.MPI_WORLD.size, 3)) + # dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_rank=3) + # dmd.fit(X,C) + # self.assertTrue(dmd.rom_eigenvalues_.shape == (3,)) + # dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_tol=1e-1) + # dmd.fit(X,C) + # # Y = ht.random.randn(10 * ht.MPI_WORLD.size, split=0) + # # Z = dmd.predict_next(Y) + # # self.assertTrue(Z.shape == (10 * ht.MPI_WORLD.size,)) + # # self.assertTrue(dmd.rom_eigenvalues_.dtype == ht.complex64) + # # self.assertTrue(dmd.dmdmodes_.dtype == ht.complex64) + + # X = ht.random.randn(1000, 10 * ht.MPI_WORLD.size, split=0, dtype=ht.float32) + # dmd = ht.decomposition.DMD(svd_solver="randomized", svd_rank=4) + # dmd.fit(X) + # # Y = ht.random.rand(1000, 2 * ht.MPI_WORLD.size, split=1, dtype=ht.float32) + # # Z = dmd.predict_next(Y, 2) + # # self.assertTrue(Z.dtype == ht.float32) + # # self.assertEqual(Z.shape, Y.shape) + + # # # wrong shape of input for prediction + # # with self.assertRaises(ValueError): + # # dmd.predict_next(ht.zeros((100, 4), split=0)) + # # with self.assertRaises(ValueError): + # # dmd.predict(ht.zeros((100, 4), split=0), 10) + # # # wrong input for steps in predict + # # with self.assertRaises(TypeError): + # # dmd.predict( + # # ht.zeros((1000, 5), split=0), + # # "this is clearly neither an integer nor a list of integers", + # # ) + + def test_dmdc_functionality_split1(self): + # check whether everything works with split=1, various checks are scattered over the different cases + X = ht.random.randn(10, 15 * ht.MPI_WORLD.size, split=1, dtype=ht.float64) + C = ht.random.randn(2, 15 * ht.MPI_WORLD.size, split=1, dtype=ht.float64) + dmd = ht.decomposition.DMDc(svd_solver="full") + dmd.fit(X, C) + self.assertTrue(dmd.dmdmodes_.shape[0] == 10) + dmd = ht.decomposition.DMDc(svd_solver="full", svd_tol=1e-1) + dmd.fit(X, C) + dmd = ht.decomposition.DMDc(svd_solver="full", svd_rank=3) + dmd.fit(X, C) + self.assertTrue(dmd.dmdmodes_.shape[1] == 3) + dmd = ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=3) + dmd.fit(X, C) + self.assertTrue(dmd.rom_transfer_matrix_.shape == (3, 3)) + self.assertTrue(dmd.rom_transfer_matrix_.dtype == ht.float64) + dmd = ht.decomposition.DMDc(svd_solver="hierarchical", svd_tol=1e-1) + dmd.fit(X, C) + self.assertTrue(dmd.rom_eigenvalues_.dtype == ht.complex128) + # Y = ht.random.randn(10, 2 * ht.MPI_WORLD.size, split=1) + # Z = dmd.predict_next(Y) + # self.assertTrue(Z.shape == Y.shape) + + # X = ht.random.randn(1000, 10 * ht.MPI_WORLD.size, split=0) + # C = ht.random.randn(10, 10 * ht.MPI_WORLD.size, split=0) + # dmd = ht.decomposition.DMDc(svd_solver="randomized", svd_rank=4) + # dmd.fit(X,C) + # self.assertTrue(dmd.rom_eigenmodes_.shape == (4, 4)) + # self.assertTrue(dmd.n_modes_ == 4) + # Y = ht.random.randn(1000, 2, split=0, dtype=ht.float64) + # Z = dmd.predict_next(Y) + # self.assertTrue(Z.dtype == Y.dtype) + # self.assertEqual(Z.shape, Y.shape)