Skip to content

Commit 2242457

Browse files
author
Hoppe
committed
added first few tests for DMDc
1 parent 3148bbe commit 2242457

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed

heat/decomposition/tests/test_dmd.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,127 @@ def test_dmd_correctness(self):
217217
# check batch prediction (split = None)
218218
X_batch = ht.random.rand(10, 2 * ht.MPI_WORLD.size, split=None)
219219
Y = dmd.predict(X_batch, [-1, 1, 3])
220+
221+
222+
class TestDMDc(TestCase):
223+
def test_dmdc_setup_and_catch_wrong(self):
224+
# catch wrong inputs
225+
with self.assertRaises(TypeError):
226+
ht.decomposition.DMDc(svd_solver=0)
227+
with self.assertRaises(ValueError):
228+
ht.decomposition.DMDc(svd_solver="Gramian")
229+
with self.assertRaises(ValueError):
230+
ht.decomposition.DMDc(svd_solver="full", svd_rank=3, svd_tol=1e-1)
231+
with self.assertRaises(ValueError):
232+
ht.decomposition.DMDc(svd_solver="full", svd_tol=-0.031415926)
233+
with self.assertRaises(ValueError):
234+
ht.decomposition.DMDc(svd_solver="hierarchical")
235+
with self.assertRaises(ValueError):
236+
ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=3, svd_tol=1e-1)
237+
with self.assertRaises(ValueError):
238+
ht.decomposition.DMDc(svd_solver="randomized")
239+
with self.assertRaises(ValueError):
240+
ht.decomposition.DMDc(svd_solver="randomized", svd_rank=2, svd_tol=1e-1)
241+
with self.assertRaises(TypeError):
242+
ht.decomposition.DMDc(svd_solver="full", svd_rank=0.1)
243+
with self.assertRaises(ValueError):
244+
ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=0)
245+
with self.assertRaises(TypeError):
246+
ht.decomposition.DMDc(svd_solver="hierarchical", svd_tol="auto")
247+
with self.assertRaises(ValueError):
248+
ht.decomposition.DMDc(svd_solver="randomized", svd_rank=0)
249+
250+
dmd = ht.decomposition.DMDc(svd_solver="full")
251+
# wrong dimensions of input
252+
with self.assertRaises(ValueError):
253+
dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 2, 2), split=0), ht.zeros((2, 4), split=0))
254+
with self.assertRaises(ValueError):
255+
dmd.fit(ht.zeros((2, 4), split=0), ht.zeros((5 * ht.MPI_WORLD.size, 2, 2), split=0))
256+
# less than two timesteps
257+
with self.assertRaises(ValueError):
258+
dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 1), split=0), ht.zeros((2, 4), split=0))
259+
with self.assertRaises(ValueError):
260+
dmd.fit(ht.zeros((2, 4), split=0), ht.zeros((5 * ht.MPI_WORLD.size, 1), split=0))
261+
# inconsistent number of timesteps
262+
with self.assertRaises(ValueError):
263+
dmd.fit(ht.zeros((5 * ht.MPI_WORLD.size, 3), split=0), ht.zeros((2, 4), split=0))
264+
265+
# def test_dmdc_functionality_split0(self):
266+
# # check whether the everything works with split=0, various checks are scattered over the different cases
267+
# X = ht.random.randn(10 * ht.MPI_WORLD.size, 10, split=0)
268+
# C = ht.random.randn(10, 10, split=0)
269+
# dmd = ht.decomposition.DMDc(svd_solver="full")
270+
# dmd.fit(X,C)
271+
# self.assertTrue(dmd.rom_eigenmodes_.dtype == ht.complex64)
272+
# self.assertEqual(dmd.rom_eigenmodes_.shape, (dmd.n_modes_, dmd.n_modes_))
273+
# dmd = ht.decomposition.DMD(svd_solver="full", svd_tol=1e-1)
274+
# dmd.fit(X,C)
275+
# self.assertTrue(dmd.rom_basis_.shape[0] == 10 * ht.MPI_WORLD.size)
276+
# dmd = ht.decomposition.DMD(svd_solver="full", svd_rank=3)
277+
# dmd.fit(X,C)
278+
# self.assertTrue(dmd.rom_basis_.shape[1] == 3)
279+
# self.assertTrue(dmd.dmdmodes_.shape == (10 * ht.MPI_WORLD.size, 3))
280+
# dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_rank=3)
281+
# dmd.fit(X,C)
282+
# self.assertTrue(dmd.rom_eigenvalues_.shape == (3,))
283+
# dmd = ht.decomposition.DMD(svd_solver="hierarchical", svd_tol=1e-1)
284+
# dmd.fit(X,C)
285+
# # Y = ht.random.randn(10 * ht.MPI_WORLD.size, split=0)
286+
# # Z = dmd.predict_next(Y)
287+
# # self.assertTrue(Z.shape == (10 * ht.MPI_WORLD.size,))
288+
# # self.assertTrue(dmd.rom_eigenvalues_.dtype == ht.complex64)
289+
# # self.assertTrue(dmd.dmdmodes_.dtype == ht.complex64)
290+
291+
# X = ht.random.randn(1000, 10 * ht.MPI_WORLD.size, split=0, dtype=ht.float32)
292+
# dmd = ht.decomposition.DMD(svd_solver="randomized", svd_rank=4)
293+
# dmd.fit(X)
294+
# # Y = ht.random.rand(1000, 2 * ht.MPI_WORLD.size, split=1, dtype=ht.float32)
295+
# # Z = dmd.predict_next(Y, 2)
296+
# # self.assertTrue(Z.dtype == ht.float32)
297+
# # self.assertEqual(Z.shape, Y.shape)
298+
299+
# # # wrong shape of input for prediction
300+
# # with self.assertRaises(ValueError):
301+
# # dmd.predict_next(ht.zeros((100, 4), split=0))
302+
# # with self.assertRaises(ValueError):
303+
# # dmd.predict(ht.zeros((100, 4), split=0), 10)
304+
# # # wrong input for steps in predict
305+
# # with self.assertRaises(TypeError):
306+
# # dmd.predict(
307+
# # ht.zeros((1000, 5), split=0),
308+
# # "this is clearly neither an integer nor a list of integers",
309+
# # )
310+
311+
def test_dmdc_functionality_split1(self):
312+
# check whether everything works with split=1, various checks are scattered over the different cases
313+
X = ht.random.randn(10, 15 * ht.MPI_WORLD.size, split=1, dtype=ht.float64)
314+
C = ht.random.randn(2, 15 * ht.MPI_WORLD.size, split=1, dtype=ht.float64)
315+
dmd = ht.decomposition.DMDc(svd_solver="full")
316+
dmd.fit(X, C)
317+
self.assertTrue(dmd.dmdmodes_.shape[0] == 10)
318+
dmd = ht.decomposition.DMDc(svd_solver="full", svd_tol=1e-1)
319+
dmd.fit(X, C)
320+
dmd = ht.decomposition.DMDc(svd_solver="full", svd_rank=3)
321+
dmd.fit(X, C)
322+
self.assertTrue(dmd.dmdmodes_.shape[1] == 3)
323+
dmd = ht.decomposition.DMDc(svd_solver="hierarchical", svd_rank=3)
324+
dmd.fit(X, C)
325+
self.assertTrue(dmd.rom_transfer_matrix_.shape == (3, 3))
326+
self.assertTrue(dmd.rom_transfer_matrix_.dtype == ht.float64)
327+
dmd = ht.decomposition.DMDc(svd_solver="hierarchical", svd_tol=1e-1)
328+
dmd.fit(X, C)
329+
self.assertTrue(dmd.rom_eigenvalues_.dtype == ht.complex128)
330+
# Y = ht.random.randn(10, 2 * ht.MPI_WORLD.size, split=1)
331+
# Z = dmd.predict_next(Y)
332+
# self.assertTrue(Z.shape == Y.shape)
333+
334+
# X = ht.random.randn(1000, 10 * ht.MPI_WORLD.size, split=0)
335+
# C = ht.random.randn(10, 10 * ht.MPI_WORLD.size, split=0)
336+
# dmd = ht.decomposition.DMDc(svd_solver="randomized", svd_rank=4)
337+
# dmd.fit(X,C)
338+
# self.assertTrue(dmd.rom_eigenmodes_.shape == (4, 4))
339+
# self.assertTrue(dmd.n_modes_ == 4)
340+
# Y = ht.random.randn(1000, 2, split=0, dtype=ht.float64)
341+
# Z = dmd.predict_next(Y)
342+
# self.assertTrue(Z.dtype == Y.dtype)
343+
# self.assertEqual(Z.shape, Y.shape)

0 commit comments

Comments
 (0)