Skip to content

Commit 16ee03e

Browse files
authored
var_cutoff can be disabled in covariance koopman models. (#255)
1 parent 153d9f6 commit 16ee03e

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

Diff for: deeptime/decomposition/_koopman.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -344,15 +344,16 @@ def var_cutoff(self) -> Optional[float]:
344344
precedence over the :meth:`dim` parameter.
345345
346346
:getter: Yields the current variance cutoff.
347-
:setter: Sets a new variance cutoff
347+
:setter: Sets a new variance cutoff or disables variance cutoff by setting the value to `None`.
348348
:type: float or None
349349
"""
350350
return self._var_cutoff
351351

352352
@var_cutoff.setter
353-
def var_cutoff(self, value):
354-
assert 0 < value <= 1., "Invalid dimension parameter, if it is given in terms of a variance cutoff, " \
355-
"it can only be in the interval (0, 1]."
353+
def var_cutoff(self, value: Optional[float]):
354+
assert value is None or 0 < value <= 1., \
355+
"Invalid dimension parameter, if it is given in terms of a variance cutoff, " \
356+
"it can only be in the interval (0, 1]."
356357
self._var_cutoff = value
357358
self._update_output_dimension()
358359

Diff for: tests/decomposition/test_tica.py

+17
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import pytest
11+
from numpy.testing import assert_, assert_equal
1112

1213
from deeptime.covariance import Covariance
1314
from deeptime.data import ellipsoids
@@ -16,6 +17,22 @@
1617
from deeptime.numeric import ZeroRankError
1718

1819

20+
def test_update_projection_dimension():
21+
# tests for https://github.com/deeptime-ml/deeptime/issues/254
22+
data = np.random.normal(size=(1000, 50))
23+
model = TICA(lagtime=1, var_cutoff=.1).fit_fetch(data)
24+
assert_equal(model.var_cutoff, .1)
25+
assert_(model.transform(data).shape[1] <= 10)
26+
model.var_cutoff = None
27+
assert_equal(model.var_cutoff, None)
28+
model.dim = 5
29+
assert_equal(model.dim, 5)
30+
assert_(model.transform(data).shape[1] <= 5)
31+
model.dim = 1
32+
assert_equal(model.dim, 1)
33+
assert_equal(model.transform(data).shape[1], 1)
34+
35+
1936
def test_fit_reset():
2037
lag = 100
2138
np.random.seed(0)

0 commit comments

Comments
 (0)