Skip to content

Commit

Permalink
added first few tests for DMDc
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoppe committed Feb 14, 2025
1 parent 3148bbe commit 2242457
Showing 1 changed file with 124 additions and 0 deletions.
124 changes: 124 additions & 0 deletions heat/decomposition/tests/test_dmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,127 @@ def test_dmd_correctness(self):
# check batch prediction (split = None)
X_batch = ht.random.rand(10, 2 * ht.MPI_WORLD.size, split=None)
Y = dmd.predict(X_batch, [-1, 1, 3])


class TestDMDc(TestCase):
def test_dmdc_setup_and_catch_wrong(self):
# catch wrong inputs
with self.assertRaises(TypeError):
ht.decomposition.DMDc(svd_solver=0)
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="Gramian")
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="full", svd_rank=3, svd_tol=1e-1)
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="full", svd_tol=-0.031415926)
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="hierarchical")
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=3, svd_tol=1e-1)
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="randomized")
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="randomized", svd_rank=2, svd_tol=1e-1)
with self.assertRaises(TypeError):
ht.decomposition.DMDc(svd_solver="full", svd_rank=0.1)
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=0)
with self.assertRaises(TypeError):
ht.decomposition.DMDc(svd_solver="hierarchical", svd_tol="auto")
with self.assertRaises(ValueError):
ht.decomposition.DMDc(svd_solver="randomized", svd_rank=0)

dmd = ht.decomposition.DMDc(svd_solver="full")
# wrong dimensions of input
with self.assertRaises(ValueError):
dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 2, 2), split=0), ht.zeros((2, 4), split=0))
with self.assertRaises(ValueError):
dmd.fit(ht.zeros((2, 4), split=0), ht.zeros((5 * ht.MPI_WORLD.size, 2, 2), split=0))
# less than two timesteps
with self.assertRaises(ValueError):
dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 1), split=0), ht.zeros((2, 4), split=0))
with self.assertRaises(ValueError):
dmd.fit(ht.zeros((2, 4), split=0), ht.zeros((5 * ht.MPI_WORLD.size, 1), split=0))
# inconsistent number of timesteps
with self.assertRaises(ValueError):
dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 3), split=0), ht.zeros((2, 4), split=0))

# def test_dmdc_functionality_split0(self):
# # check whether the everything works with split=0, various checks are scattered over the different cases
# X = ht.random.randn(10 * ht.MPI_WORLD.size, 10, split=0)
# C = ht.random.randn(10, 10, split=0)
# dmd = ht.decomposition.DMDc(svd_solver="full")
# dmd.fit(X,C)
# self.assertTrue(dmd.rom_eigenmodes_.dtype == ht.complex64)
# self.assertEqual(dmd.rom_eigenmodes_.shape, (dmd.n_modes_, dmd.n_modes_))
# dmd = ht.decomposition.DMD(svd_solver="full", svd_tol=1e-1)
# dmd.fit(X,C)
# self.assertTrue(dmd.rom_basis_.shape[0] == 10 * ht.MPI_WORLD.size)
# dmd = ht.decomposition.DMD(svd_solver="full", svd_rank=3)
# dmd.fit(X,C)
# self.assertTrue(dmd.rom_basis_.shape[1] == 3)
# self.assertTrue(dmd.dmdmodes_.shape == (10 * ht.MPI_WORLD.size, 3))
# dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_rank=3)
# dmd.fit(X,C)
# self.assertTrue(dmd.rom_eigenvalues_.shape == (3,))
# dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_tol=1e-1)
# dmd.fit(X,C)
# # Y = ht.random.randn(10 * ht.MPI_WORLD.size, split=0)
# # Z = dmd.predict_next(Y)
# # self.assertTrue(Z.shape == (10 * ht.MPI_WORLD.size,))
# # self.assertTrue(dmd.rom_eigenvalues_.dtype == ht.complex64)
# # self.assertTrue(dmd.dmdmodes_.dtype == ht.complex64)

# X = ht.random.randn(1000, 10 * ht.MPI_WORLD.size, split=0, dtype=ht.float32)
# dmd = ht.decomposition.DMD(svd_solver="randomized", svd_rank=4)
# dmd.fit(X)
# # Y = ht.random.rand(1000, 2 * ht.MPI_WORLD.size, split=1, dtype=ht.float32)
# # Z = dmd.predict_next(Y, 2)
# # self.assertTrue(Z.dtype == ht.float32)
# # self.assertEqual(Z.shape, Y.shape)

# # # wrong shape of input for prediction
# # with self.assertRaises(ValueError):
# # dmd.predict_next(ht.zeros((100, 4), split=0))
# # with self.assertRaises(ValueError):
# # dmd.predict(ht.zeros((100, 4), split=0), 10)
# # # wrong input for steps in predict
# # with self.assertRaises(TypeError):
# # dmd.predict(
# # ht.zeros((1000, 5), split=0),
# # "this is clearly neither an integer nor a list of integers",
# # )

def test_dmdc_functionality_split1(self):
# check whether everything works with split=1, various checks are scattered over the different cases
X = ht.random.randn(10, 15 * ht.MPI_WORLD.size, split=1, dtype=ht.float64)
C = ht.random.randn(2, 15 * ht.MPI_WORLD.size, split=1, dtype=ht.float64)
dmd = ht.decomposition.DMDc(svd_solver="full")
dmd.fit(X, C)
self.assertTrue(dmd.dmdmodes_.shape[0] == 10)
dmd = ht.decomposition.DMDc(svd_solver="full", svd_tol=1e-1)
dmd.fit(X, C)
dmd = ht.decomposition.DMDc(svd_solver="full", svd_rank=3)
dmd.fit(X, C)
self.assertTrue(dmd.dmdmodes_.shape[1] == 3)
dmd = ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=3)
dmd.fit(X, C)
self.assertTrue(dmd.rom_transfer_matrix_.shape == (3, 3))
self.assertTrue(dmd.rom_transfer_matrix_.dtype == ht.float64)
dmd = ht.decomposition.DMDc(svd_solver="hierarchical", svd_tol=1e-1)
dmd.fit(X, C)
self.assertTrue(dmd.rom_eigenvalues_.dtype == ht.complex128)
# Y = ht.random.randn(10, 2 * ht.MPI_WORLD.size, split=1)
# Z = dmd.predict_next(Y)
# self.assertTrue(Z.shape == Y.shape)

# X = ht.random.randn(1000, 10 * ht.MPI_WORLD.size, split=0)
# C = ht.random.randn(10, 10 * ht.MPI_WORLD.size, split=0)
# dmd = ht.decomposition.DMDc(svd_solver="randomized", svd_rank=4)
# dmd.fit(X,C)
# self.assertTrue(dmd.rom_eigenmodes_.shape == (4, 4))
# self.assertTrue(dmd.n_modes_ == 4)
# Y = ht.random.randn(1000, 2, split=0, dtype=ht.float64)
# Z = dmd.predict_next(Y)
# self.assertTrue(Z.dtype == Y.dtype)
# self.assertEqual(Z.shape, Y.shape)

0 comments on commit 2242457

Please sign in to comment.