@@ -1,8 +1,9 @@
 from math import sqrt
-from typing import Optional, Callable
+from typing import Optional, Callable, Union, List

 import numpy as np

+from .._transition_counting import TransitionCountEstimator
 from ...base import Estimator
 from ...numeric import is_square_matrix
 from .._base import _MSMBaseEstimator, BayesianMSMPosterior
@@ -37,17 +38,17 @@ class BayesianMSM(_MSMBaseEstimator):
         this case python sparse matrices will be returned by the corresponding functions instead of numpy arrays.
         This behavior is suggested for very large numbers of states (e.g. > 4000) because it is likely to be much
         more efficient.
-    confidence : float, optional, default=0.954
-        Confidence interval. By default two sigma (95.4%) is used. Use 68.3% for one sigma, 99.7% for three sigma.
     maxiter : int, optional, default=1000000
         Optional parameter with reversible = True, sets the maximum number of iterations before the transition
         matrix estimation method exits.
-    maxerr : float, optional, default = 1e-8
+    maxerr : float, optional, default=1e-8
         Optional parameter with reversible = True. Convergence tolerance for transition matrix estimation. This
         specifies the maximum change of the Euclidean norm of relative stationary probabilities
         (:math:`x_i = \sum_k x_{ik}`). The relative stationary probability changes
         :math:`e_i = (x_i^{(1)} - x_i^{(2)})/(x_i^{(1)} + x_i^{(2)})` are used in order to track changes in small
         probabilities. The Euclidean norm of the change vector, :math:`|e_i|_2`, is compared to maxerr.
+    lagtime : int, optional, default=None
+        The lagtime that is used when fitting directly from discrete trajectories.

     References
     ----------
@@ -130,14 +131,14 @@ class BayesianMSM(_MSMBaseEstimator):

     def __init__(self, n_samples: int = 100, n_steps: int = None, reversible: bool = True,
                  stationary_distribution_constraint: Optional[np.ndarray] = None,
-                 sparse: bool = False, confidence: float = 0.954, maxiter: int = int(1e6), maxerr: float = 1e-8):
+                 sparse: bool = False, maxiter: int = int(1e6), maxerr: float = 1e-8, lagtime: Optional[int] = None):
         super(BayesianMSM, self).__init__(reversible=reversible, sparse=sparse)
         self.stationary_distribution_constraint = stationary_distribution_constraint
         self.maxiter = maxiter
         self.maxerr = maxerr
         self.n_samples = n_samples
         self.n_steps = n_steps
-        self.confidence = confidence
+        self.lagtime = lagtime

     @property
     def stationary_distribution_constraint(self) -> Optional[np.ndarray]:
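With this change the constructor drops `confidence` and gains an optional `lagtime`. A minimal construction sketch, assuming the diff above is applied; the parameter values are purely illustrative:

```python
from deeptime.markov.msm import BayesianMSM

# 'confidence' is no longer a constructor argument in this change; storing a
# lagtime on the estimator lets fit() consume discrete trajectories directly.
estimator = BayesianMSM(n_samples=100, reversible=True, lagtime=10)
```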
@@ -203,18 +204,13 @@ def fit(self, data, callback: Callable = None, **kw):

         from deeptime.markov import TransitionCountModel
         if isinstance(data, TransitionCountModel) or is_square_matrix(data):
-            msm = MaximumLikelihoodMSM(
-                reversible=self.reversible, stationary_distribution_constraint=self.stationary_distribution_constraint,
-                sparse=self.sparse, maxiter=self.maxiter, maxerr=self.maxerr
-            ).fit(data).fetch_model()
+            return self.fit_from_counts(data)
         elif isinstance(data, MarkovStateModel):
-            msm = data
+            return self.fit_from_msm(data, callback=callback, **kw)
         else:
-            raise ValueError("Unsupported input data, can only be count matrix (or TransitionCountModel, "
-                             "TransitionCountEstimator) or a MarkovStateModel instance or an estimator producing "
-                             "Markov state models.")
-
-        return self.fit_from_msm(msm, callback=callback, **kw)
+            if not self.lagtime and 'lagtime' not in kw.keys():
+                raise ValueError("To fit directly from a discrete timeseries, a lagtime must be provided!")
+            return self.fit_from_discrete_timeseries(data, kw.pop('lagtime', self.lagtime), callback=callback, **kw)

     def sample(self, prior: MarkovStateModel, n_samples: int, n_steps: Optional[int] = None, callback=None):
         r""" Performs sampling based on a prior.
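A short sketch of how the reworked `fit` dispatch behaves once this change is in place; the trajectory, number of states, and lagtime are made up for illustration:

```python
import numpy as np
from deeptime.markov import TransitionCountEstimator
from deeptime.markov.msm import BayesianMSM

dtraj = np.random.randint(0, 3, size=1000)  # toy discrete trajectory

# Count models (and square count matrices) are routed to fit_from_counts(...)
counts = TransitionCountEstimator(lagtime=10, count_mode='effective') \
    .fit_fetch(dtraj).submodel_largest()
posterior = BayesianMSM().fit(counts).fetch_model()

# Discrete trajectories need a lagtime, either set on the estimator ...
posterior = BayesianMSM(lagtime=10).fit(dtraj).fetch_model()
# ... or passed as a keyword argument to fit()
posterior = BayesianMSM().fit(dtraj, lagtime=10).fetch_model()
```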
@@ -310,6 +306,55 @@ def fit_from_msm(self, msm: MarkovStateModel, callback=None, **kw):
         self._model = BayesianMSMPosterior(prior=msm, samples=samples)
         return self

+    def fit_from_discrete_timeseries(self, discrete_timeseries: Union[np.ndarray, List[np.ndarray]],
+                                     lagtime: int = None, count_mode: str = 'effective', callback=None, **kw):
+        r""" Fits a BayesianMSM directly on timeseries data.
+
+        Parameters
+        ----------
+        discrete_timeseries : list of ndarray
+            Discrete trajectories.
+        lagtime : int, optional, default=None
+            The lagtime used for counting. If None, the estimator's ``lagtime`` attribute is substituted when calling through :meth:`fit`.
+        count_mode : str, default='effective'
+            The counting mode. Should be of the `effective` kind, otherwise the results may be heavily biased.
+        callback : callable, optional, default=None
+            Function to be called to indicate progress of sampling.
+        **kw
+            Optional keyword parameters.
+
+        Returns
+        -------
+        self : BayesianMSM
+            Reference to self.
+        """
+        counts = TransitionCountEstimator(lagtime=lagtime, count_mode=count_mode,
+                                          sparse=self.sparse).fit_fetch(discrete_timeseries).submodel_largest()
+        return self.fit_from_counts(counts, callback=callback, **kw)
+
+    def fit_from_counts(self, counts, callback=None, **kw):
+        r"""Fits a Bayesian MSM on a count model or a count matrix.
+
+        Parameters
+        ----------
+        counts : TransitionCountModel or (n, n) ndarray
+            The transition counts.
+        callback : callable, optional, default=None
+            Function that is called to indicate progress of sampling.
+        **kw
+            Optional keyword parameters.
+
+        Returns
+        -------
+        self : BayesianMSM
+            Reference to self.
+        """
+        msm = MaximumLikelihoodMSM(
+            reversible=self.reversible, stationary_distribution_constraint=self.stationary_distribution_constraint,
+            sparse=self.sparse, maxiter=self.maxiter, maxerr=self.maxerr
+        ).fit(counts).fetch_model()
+        return self.fit_from_msm(msm, callback=callback, **kw)
+
     @deprecated_method("Deprecated in v0.4.1 and will be removed soon, please use model.ck_test.")
     def chapman_kolmogorov_validator(self, n_metastable_sets: int, mlags, test_model=None):
         r""" Replaced by `deeptime.markov.msm.BayesianMSMPosterior.ck_test`. """
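The two new helper methods can also be called directly. A usage sketch under the same assumptions as above (toy data, illustrative parameters):

```python
import numpy as np
from deeptime.markov.msm import BayesianMSM

dtrajs = [np.random.randint(0, 4, size=500) for _ in range(3)]  # toy trajectories

estimator = BayesianMSM(n_samples=50)
posterior = estimator.fit_from_discrete_timeseries(dtrajs, lagtime=5,
                                                   count_mode='effective').fetch_model()

# The posterior bundles the maximum-likelihood prior and the sampled models
print(posterior.prior.n_states, len(posterior.samples))
```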