Skip to content

Commit dc320db

Browse files
authored
Minor changes (#208)
1 parent a5e488f commit dc320db

14 files changed

+126
-64
lines changed

deeptime/markov/_base.py

-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ def lagtime(self):
139139
def ck_test(self, models, n_metastable_sets, include_lag0=True, err_est=False, progress=None):
140140
r""" Performs a Chapman Kolmogorov test.
141141
See :meth:`MarkovStateModel.ck_test <deeptime.markov.msm.MarkovStateModel.ck_test>` for more details """
142-
from deeptime.util.validation import ChapmanKolmogorovTest
143142
clustering = self.prior.pcca(n_metastable_sets)
144143
observable = MembershipsObservable(self, clustering, initial_distribution=self.prior.stationary_distribution)
145144
from deeptime.util.validation import ck_test

deeptime/markov/_observables.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numbers
12
from typing import Union
23

34
import numpy as np
@@ -53,7 +54,12 @@ def __call__(self, model, mlag=1, **kw):
5354
return np.eye(self.n_sets)
5455
model = MembershipsObservable._to_markov_model(model)
5556
# otherwise compute or predict them by model.propagate
56-
pk_on_set = np.zeros((self.n_sets, self.n_sets), dtype=float if self.ignore_imaginary_parts else complex)
57+
integer_lag = isinstance(mlag, numbers.Integral)
58+
if self.ignore_imaginary_parts or (model.is_real and integer_lag and np.all(np.isreal(self.P0))):
59+
dtype = float
60+
else:
61+
dtype = complex
62+
pk_on_set = np.zeros((self.n_sets, self.n_sets), dtype=dtype)
5763
# compute observable on prior in case for Bayesian models.
5864
symbols = model.count_model.state_symbols
5965
subset = self._full2active[symbols] # find subset we are now working on

deeptime/markov/msm/_bayesian_msm.py

+61-16
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from math import sqrt
2-
from typing import Optional, Callable
2+
from typing import Optional, Callable, Union, List
33

44
import numpy as np
55

6+
from .._transition_counting import TransitionCountEstimator
67
from ...base import Estimator
78
from ...numeric import is_square_matrix
89
from .._base import _MSMBaseEstimator, BayesianMSMPosterior
@@ -37,17 +38,17 @@ class BayesianMSM(_MSMBaseEstimator):
3738
this case python sparse matrices will be returned by the corresponding functions instead of numpy arrays.
3839
This behavior is suggested for very large numbers of states (e.g. > 4000) because it is likely to be much
3940
more efficient.
40-
confidence : float, optional, default=0.954
41-
Confidence interval. By default two sigma (95.4%) is used. Use 68.3% for one sigma, 99.7% for three sigma.
4241
maxiter : int, optional, default=1000000
4342
Optional parameter with reversible = True, sets the maximum number of iterations before the transition
4443
matrix estimation method exits.
45-
maxerr : float, optional, default = 1e-8
44+
maxerr : float, optional, default=1e-8
4645
Optional parameter with reversible = True. Convergence tolerance for transition matrix estimation. This
4746
specifies the maximum change of the Euclidean norm of relative stationary probabilities
4847
(:math:`x_i = \sum_k x_{ik}`). The relative stationary probability changes
4948
:math:`e_i = (x_i^{(1)} - x_i^{(2)})/(x_i^{(1)} + x_i^{(2)})` are used in order to track changes in small
5049
probabilities. The Euclidean norm of the change vector, :math:`|e_i|_2`, is compared to maxerr.
50+
lagtime : int, optional, default=None
51+
The lagtime that is used when fitting directly from discrete trajectories.
5152
5253
References
5354
----------
@@ -130,14 +131,14 @@ class BayesianMSM(_MSMBaseEstimator):
130131

131132
def __init__(self, n_samples: int = 100, n_steps: int = None, reversible: bool = True,
132133
stationary_distribution_constraint: Optional[np.ndarray] = None,
133-
sparse: bool = False, confidence: float = 0.954, maxiter: int = int(1e6), maxerr: float = 1e-8):
134+
sparse: bool = False, maxiter: int = int(1e6), maxerr: float = 1e-8, lagtime: Optional[int] = None):
134135
super(BayesianMSM, self).__init__(reversible=reversible, sparse=sparse)
135136
self.stationary_distribution_constraint = stationary_distribution_constraint
136137
self.maxiter = maxiter
137138
self.maxerr = maxerr
138139
self.n_samples = n_samples
139140
self.n_steps = n_steps
140-
self.confidence = confidence
141+
self.lagtime = lagtime
141142

142143
@property
143144
def stationary_distribution_constraint(self) -> Optional[np.ndarray]:
@@ -203,18 +204,13 @@ def fit(self, data, callback: Callable = None, **kw):
203204

204205
from deeptime.markov import TransitionCountModel
205206
if isinstance(data, TransitionCountModel) or is_square_matrix(data):
206-
msm = MaximumLikelihoodMSM(
207-
reversible=self.reversible, stationary_distribution_constraint=self.stationary_distribution_constraint,
208-
sparse=self.sparse, maxiter=self.maxiter, maxerr=self.maxerr
209-
).fit(data).fetch_model()
207+
return self.fit_from_counts(data)
210208
elif isinstance(data, MarkovStateModel):
211-
msm = data
209+
return self.fit_from_msm(data, callback=callback, **kw)
212210
else:
213-
raise ValueError("Unsupported input data, can only be count matrix (or TransitionCountModel, "
214-
"TransitionCountEstimator) or a MarkovStateModel instance or an estimator producing "
215-
"Markov state models.")
216-
217-
return self.fit_from_msm(msm, callback=callback, **kw)
211+
if not self.lagtime and 'lagtime' not in kw.keys():
212+
raise ValueError("To fit directly from a discrete timeseries, a lagtime must be provided!")
213+
return self.fit_from_discrete_timeseries(data, kw.pop('lagtime', self.lagtime), callback=callback, **kw)
218214

219215
def sample(self, prior: MarkovStateModel, n_samples: int, n_steps: Optional[int] = None, callback=None):
220216
r""" Performs sampling based on a prior.
@@ -310,6 +306,55 @@ def fit_from_msm(self, msm: MarkovStateModel, callback=None, **kw):
310306
self._model = BayesianMSMPosterior(prior=msm, samples=samples)
311307
return self
312308

309+
def fit_from_discrete_timeseries(self, discrete_timeseries: Union[np.ndarray, List[np.ndarray]],
310+
lagtime: int = None, count_mode: str = 'effective', callback=None, **kw):
311+
r""" Fits a BayesianMSM directly on timeseries data.
312+
313+
Parameters
314+
----------
315+
discrete_timeseries : list of ndarray
316+
Discrete trajectories.
317+
lagtime : int, optional, default=None
318+
The lagtime that is used for estimation. If None, uses the instance's lagtime attribute.
319+
count_mode : str, default='effective'
320+
The counting mode. Should be of the `effective` kind, otherwise the results may be heavily biased.
321+
callback : callable, optional, default=None
322+
Function to be called to indicate progress of sampling.
323+
**kw
324+
Optional keyword parameters.
325+
326+
Returns
327+
-------
328+
self : BayesianMSM
329+
Reference to self.
330+
"""
331+
counts = TransitionCountEstimator(lagtime=lagtime, count_mode=count_mode,
332+
sparse=self.sparse).fit_fetch(discrete_timeseries).submodel_largest()
333+
return self.fit_from_counts(counts, callback=callback, **kw)
334+
335+
def fit_from_counts(self, counts, callback=None, **kw):
336+
r"""Fits a bayesian MSM on a count model or a count matrix.
337+
338+
Parameters
339+
----------
340+
counts : TransitionCountModel or (n, n) ndarray
341+
The transition counts.
342+
callback : callable, optional, default=None
343+
Function that is called to indicate progress of sampling.
344+
**kw
345+
Optional keyword parameters.
346+
347+
Returns
348+
-------
349+
self : BayesianMSM
350+
Reference to self.
351+
"""
352+
msm = MaximumLikelihoodMSM(
353+
reversible=self.reversible, stationary_distribution_constraint=self.stationary_distribution_constraint,
354+
sparse=self.sparse, maxiter=self.maxiter, maxerr=self.maxerr
355+
).fit(counts).fetch_model()
356+
return self.fit_from_msm(msm, callback=callback, **kw)
357+
313358
@deprecated_method("Deprecated in v0.4.1 and will be removed soon, please use model.ck_test.")
314359
def chapman_kolmogorov_validator(self, n_metastable_sets: int, mlags, test_model=None):
315360
r""" Replaced by `deeptime.markov.msm.BayesianMSMPosterior.ck_test`. """

deeptime/markov/msm/_markov_state_model.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def _ensure_eigenvalues(self, neig=None):
360360
if m < neig:
361361
# not enough eigenpairs present - recompute:
362362
self._compute_eigenvalues(neig)
363-
except (AttributeError, TypeError) as e:
363+
except (AttributeError, TypeError):
364364
# no eigendecomposition yet - compute:
365365
self._compute_eigenvalues(neig)
366366

@@ -530,6 +530,12 @@ def _transition_matrix_power(self, power):
530530
])
531531
return transition_matrix
532532

533+
@cached_property
534+
def is_real(self):
535+
r""" Checks if all eigenvalues as well as eigenvectors/functions are real. """
536+
return np.all(np.isreal(self.eigenvalues())) and \
537+
np.all(np.isreal(self.eigenvectors_left()) & np.isreal(self.eigenvectors_right()))
538+
533539
def propagate(self, p0, k: int):
534540
r""" Propagates the initial distribution p0 k times
535541

deeptime/markov/msm/_maximum_likelihood_msm.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class MaximumLikelihoodMSM(_MSMBaseEstimator):
5656
Number of counts required to consider two states connected.
5757
lagtime : int, optional, default=None
5858
Optional lagtime that can be provided at estimator level if fitting from timeseries directly.
59+
use_lcc : bool, default=False
60+
If set to true, this will restrict the resulting MSM collection to only contain the largest connected
61+
state-space component.
5962
6063
References
6164
----------
@@ -64,7 +67,8 @@ class MaximumLikelihoodMSM(_MSMBaseEstimator):
6467

6568
def __init__(self, reversible: bool = True, stationary_distribution_constraint: Optional[np.ndarray] = None,
6669
sparse: bool = False, allow_disconnected: bool = False, maxiter: int = int(1e6), maxerr: float = 1e-8,
67-
connectivity_threshold: float = 0, transition_matrix_tolerance: float = 1e-6, lagtime=None):
70+
connectivity_threshold: float = 0, transition_matrix_tolerance: float = 1e-6, lagtime=None,
71+
use_lcc: bool=False):
6872
super(MaximumLikelihoodMSM, self).__init__(reversible=reversible, sparse=sparse)
6973

7074
self.stationary_distribution_constraint = stationary_distribution_constraint
@@ -74,6 +78,7 @@ def __init__(self, reversible: bool = True, stationary_distribution_constraint:
7478
self.connectivity_threshold = connectivity_threshold
7579
self.transition_matrix_tolerance = transition_matrix_tolerance
7680
self.lagtime = lagtime
81+
self.use_lcc = use_lcc
7782

7883
@property
7984
def allow_disconnected(self) -> bool:
@@ -226,6 +231,8 @@ def fit_from_counts(self, counts: Union[np.ndarray, TransitionCountEstimator, Tr
226231
transition_matrices = []
227232
statdists = []
228233
count_models = []
234+
if self.use_lcc:
235+
sets = [sets[0]]
229236
for subset in sets:
230237
try:
231238
sub_counts = counts.submodel(subset)

deeptime/markov/tools/analysis/_assessment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def is_reversible(T, mu=None, tol=1e-15):
105105
mu = stationary_distribution(T)
106106

107107
if sparse.issparse(T):
108-
prod = sparse.construct.diags(mu) * T
108+
prod = sparse.diags(mu) * T
109109
else:
110110
prod = mu[:, None] * T
111111

deeptime/markov/tools/analysis/dense/_correlations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def time_relaxation_direct_by_diagonalization(P, p0, obs, time, rdl=None):
233233
return result
234234

235235

236-
def time_relaxations_direct(P, p0, obs, times=[1]):
236+
def time_relaxations_direct(P, p0, obs, times=(1,)):
237237
r"""Compute time-relaxations of obs with respect of given initial distribution.
238238
239239
relaxation(k) = p0 P^k obs

deeptime/markov/tools/estimation/api.py

+10-25
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from scipy.sparse import coo_matrix
1414
from scipy.sparse import csr_matrix
1515
from scipy.sparse import issparse
16-
from scipy.sparse.sputils import isdense
1716

1817
from deeptime.util.types import ensure_dtraj_list
1918
from . import dense
@@ -378,10 +377,7 @@ def connected_sets(C, directed=True):
378377
[array([0, 1, 2])]
379378
380379
"""
381-
if isdense(C):
382-
return sparse.connectivity.connected_sets(csr_matrix(C), directed=directed)
383-
else:
384-
return sparse.connectivity.connected_sets(C, directed=directed)
380+
return sparse.connectivity.connected_sets(C if issparse(C) else csr_matrix(C), directed=directed)
385381

386382

387383
def largest_connected_set(C, directed=True):
@@ -432,10 +428,7 @@ def largest_connected_set(C, directed=True):
432428
array([0, 1, 2])
433429
434430
"""
435-
if isdense(C):
436-
return sparse.connectivity.largest_connected_set(csr_matrix(C), directed=directed)
437-
else:
438-
return sparse.connectivity.largest_connected_set(C, directed=directed)
431+
return sparse.connectivity.largest_connected_set(C if issparse(C) else csr_matrix(C), directed=directed)
439432

440433

441434
def largest_connected_submatrix(C, directed=True, lcc=None):
@@ -492,10 +485,9 @@ def largest_connected_submatrix(C, directed=True, lcc=None):
492485
[ 0, 0, 4]]...)
493486
494487
"""
495-
if isdense(C):
496-
return sparse.connectivity.largest_connected_submatrix(csr_matrix(C), directed=directed, lcc=lcc).toarray()
497-
else:
498-
return sparse.connectivity.largest_connected_submatrix(C, directed=directed, lcc=lcc)
488+
lcc = sparse.connectivity.largest_connected_submatrix(C if issparse(C) else csr_matrix(C),
489+
directed=directed, lcc=lcc)
490+
return lcc if issparse(C) else lcc.toarray()
499491

500492

501493
def is_connected(C, directed=True):
@@ -542,10 +534,7 @@ def is_connected(C, directed=True):
542534
True
543535
544536
"""
545-
if isdense(C):
546-
return sparse.connectivity.is_connected(csr_matrix(C), directed=directed)
547-
else:
548-
return sparse.connectivity.is_connected(C, directed=directed)
537+
return sparse.connectivity.is_connected(C if issparse(C) else csr_matrix(C), directed=directed)
549538

550539

551540
################################################################################
@@ -591,7 +580,7 @@ def prior_neighbor(C, alpha=0.001):
591580
592581
"""
593582

594-
if isdense(C):
583+
if not issparse(C):
595584
B = sparse.prior.prior_neighbor(csr_matrix(C), alpha=alpha)
596585
return B.toarray()
597586
else:
@@ -633,7 +622,7 @@ def prior_const(C, alpha=0.001):
633622
[0.001, 0.001, 0.001]])
634623
635624
"""
636-
if not isdense(C):
625+
if issparse(C):
637626
warnings.warn("Prior will be a dense matrix for sparse input")
638627
return sparse.prior.prior_const(C, alpha=alpha)
639628

@@ -690,11 +679,7 @@ def prior_rev(C, alpha=-1.0):
690679
[ 0., 0., -1.]])
691680
692681
"""
693-
if isdense(C):
694-
return sparse.prior.prior_rev(C, alpha=alpha)
695-
else:
696-
# warnings.warn("Prior will be a dense matrix for sparse input")
697-
return sparse.prior.prior_rev(C, alpha=alpha)
682+
return sparse.prior.prior_rev(C, alpha=alpha)
698683

699684

700685
################################################################################
@@ -803,7 +788,7 @@ def transition_matrix(C, reversible=False, mu=None, method='auto',
803788
"""
804789
if issparse(C):
805790
sparse_input_type = True
806-
elif isdense(C):
791+
elif isinstance(C, np.ndarray):
807792
sparse_input_type = False
808793
else:
809794
raise NotImplementedError('C has an unknown type.')

deeptime/markov/tools/estimation/sparse/effective_counts.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import numpy as np
88
import scipy.sparse
9-
from scipy.sparse.csr import csr_matrix
109

1110
from threadpoolctl import threadpool_limits
1211

@@ -215,7 +214,7 @@ def statistical_inefficiencies(dtrajs, lag, C=None, truncate_acf=True, mact=2.0,
215214
truncate_acf=truncate_acf, mact=mact)
216215
if callback is not None:
217216
callback(1)
218-
res = csr_matrix((data, (I, J)), shape=C.shape)
217+
res = scipy.sparse.csr_matrix((data, (I, J)), shape=C.shape)
219218
return res
220219

221220

deeptime/markov/tools/flux/api.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as _np
2-
from scipy.sparse import csr_matrix
3-
from scipy.sparse.base import issparse
2+
from scipy.sparse import csr_matrix, issparse
43

54
from deeptime.util.sparse import remove_negative_entries
65

0 commit comments

Comments
 (0)