Skip to content

Commit 2f16e0f

Browse files
authored
Cluster progress (#195)
* fix custom metric example * make Index projection consistently compile-time enabled * rework progress handling * tram toc update * doc fixes and tests * notebooks update
1 parent 5a785b2 commit 2f16e0f

File tree

19 files changed

+229
-79
lines changed

19 files changed

+229
-79
lines changed

CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ add_subdirectory(deeptime/markov/msm/tram/_bindings)
3232
add_subdirectory(deeptime/markov/tools/estimation/dense/_bindings)
3333
add_subdirectory(deeptime/markov/tools/estimation/sparse/_bindings)
3434

35+
add_subdirectory(examples/clustering_custom_metric)
36+
3537
if(DEEPTIME_BUILD_CPP_TESTS)
3638
add_subdirectory(tests)
3739
endif()

deeptime/clustering/_kmeans.py

+40-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import random
22
import warnings
3+
from contextlib import nullcontext
34
from typing import Optional
45

56
import numpy as np
67

78
from ..base import EstimatorTransformer
89
from ._cluster_model import ClusterModel
910
from . import metrics
11+
from ..util.callbacks import ProgressCallback
1012

1113
from ..util.parallel import handle_n_jobs
1214

@@ -173,6 +175,10 @@ class KMeans(EstimatorTransformer):
173175
initial_centers: None or np.ndarray[k, dim], default=None
174176
This is used to resume the kmeans iteration. Note that if this is set, the init_strategy is ignored and
175177
the centers are directly passed to the kmeans iteration algorithm.
178+
progress : object
179+
Progress bar object that `KMeans` will call to indicate progress to the user. Tested for a tqdm progress bar.
180+
The interface is checked
181+
via :meth:`supports_progress_interface <deeptime.util.callbacks.supports_progress_interface>`.
176182
177183
References
178184
----------
@@ -186,7 +192,7 @@ class KMeans(EstimatorTransformer):
186192

187193
def __init__(self, n_clusters: int, max_iter: int = 500, metric='euclidean',
188194
tolerance=1e-5, init_strategy: str = 'kmeans++', fixed_seed=False,
189-
n_jobs=None, initial_centers=None):
195+
n_jobs=None, initial_centers=None, progress=None):
190196
super(KMeans, self).__init__()
191197

192198
self.n_clusters = n_clusters
@@ -198,6 +204,7 @@ def __init__(self, n_clusters: int, max_iter: int = 500, metric='euclidean',
198204
self.random_state = np.random.RandomState(self.fixed_seed)
199205
self.n_jobs = handle_n_jobs(n_jobs)
200206
self.initial_centers = initial_centers
207+
self.progress = progress
201208

202209
@property
203210
def initial_centers(self) -> Optional[np.ndarray]:
@@ -421,14 +428,30 @@ def fit(self, data, initial_centers=None, callback_init_centers=None, callback_l
421428
if initial_centers is not None:
422429
self.initial_centers = initial_centers
423430
if self.initial_centers is None:
424-
self.initial_centers = self._pick_initial_centers(data, self.init_strategy, n_jobs, callback_init_centers)
431+
if self.progress is not None:
432+
callback = KMeansCallback(self.progress, "KMeans++ initialization", self.n_clusters,
433+
callback_init_centers)
434+
context = callback
435+
else:
436+
callback = callback_init_centers
437+
context = nullcontext()
438+
with context:
439+
self.initial_centers = self._pick_initial_centers(data, self.init_strategy, n_jobs, callback)
425440

426441
# run k-means with all the data
427442
converged = False
428443
impl = metrics[self.metric]
429-
cluster_centers, code, iterations, cost = impl.kmeans.cluster_loop(
430-
data, self.initial_centers.copy(), n_jobs, self.max_iter,
431-
self.tolerance, callback_loop)
444+
445+
if self.progress is not None:
446+
callback = KMeansCallback(self.progress, "KMeans iterations", self.max_iter, callback_loop)
447+
context = callback
448+
else:
449+
callback = callback_loop
450+
context = nullcontext()
451+
with context:
452+
cluster_centers, code, iterations, cost = impl.kmeans.cluster_loop(
453+
data, self.initial_centers.copy(), n_jobs, self.max_iter,
454+
self.tolerance, callback)
432455
if code == 0:
433456
converged = True
434457
else:
@@ -526,3 +549,15 @@ def partial_fit(self, data, n_jobs=None):
526549
self._model._converged = True
527550

528551
return self
552+
553+
554+
class KMeansCallback(ProgressCallback):
555+
556+
def __init__(self, progress, description, total, parent_callback=None):
557+
super().__init__(progress, description=description, total=total)
558+
self._parent_callback = parent_callback
559+
560+
def __call__(self, *args, **kw):
561+
super().__call__(*args, **kw)
562+
if self._parent_callback is not None:
563+
self._parent_callback(*args, **kw)

deeptime/markov/msm/tram/_tram.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ class TRAM(_MSMBaseEstimator):
8686
`track_log_likelihoods=True`, the log-likelihoods are also stored. If `callback_interval=0`, no call to the
8787
callback function is done.
8888
progress : object
89-
Progress bar object that `TRAM` will call to indicate progress to the user.
90-
Tested for a tqdm progress bar. Should implement `update()` and `close()` and have `total` and `desc`
91-
properties.
89+
Progress bar object that `TRAM` will call to indicate progress to the user. Tested for a tqdm progress bar.
90+
The interface is checked
91+
via :meth:`supports_progress_interface <deeptime.util.callbacks.supports_progress_interface>`.
9292
9393
See also
9494
--------
@@ -192,8 +192,7 @@ def fit(self, data, model=None, *args, **kw):
192192
return self
193193

194194
def _run_estimation(self, tram_input):
195-
""" Estimate the free energies using self-consistent iteration as described in the TRAM paper.
196-
"""
195+
""" Estimate the free energies using self-consistent iteration as described in the TRAM paper. """
197196
with TRAMCallback(self.progress, self.maxiter, self.log_likelihoods, self.increments,
198197
self.callback_interval > 0) as callback:
199198
self._tram_estimator.estimate(tram_input, self.maxiter, self.maxerr,
@@ -202,26 +201,28 @@ def _run_estimation(self, tram_input):
202201

203202
if callback.last_increment > self.maxerr:
204203
warnings.warn(
205-
f"TRAM did not converge after {self.maxiter} iteration. Last increment: {callback.last_increment}",
206-
ConvergenceWarning)
204+
f"TRAM did not converge after {self.maxiter} iteration(s). "
205+
f"Last increment: {callback.last_increment}", ConvergenceWarning)
207206

208207

209-
class TRAMCallback(callbacks.Callback):
208+
class TRAMCallback(callbacks.ProgressCallback):
210209
"""Callback for the TRAM estimate process. Increments a progress bar and optionally saves iteration increments and
211210
log likelihoods to a list.
212211
213212
Parameters
214213
----------
215214
log_likelihoods_list : list, optional
216215
A list to which the log-likelihoods passed to the callback.__call__() method are appended.
216+
total : int
217+
Maximum number of callbacks.
217218
increments : list, optional
218219
A list to which the increments passed to the callback.__call__() method are appended.
219220
store_convergence_info : bool, default=False
220221
If True, log_likelihoods and increments are appended to their respective lists each time callback.__call__() is
221222
called. If False, no values are appended; only the last increment is stored.
222223
"""
223-
def __init__(self, progress, n_iter, log_likelihoods_list=None, increments=None, store_convergence_info=False):
224-
super().__init__(progress, n_iter=n_iter, display_text="Running TRAM estimate")
224+
def __init__(self, progress, total, log_likelihoods_list=None, increments=None, store_convergence_info=False):
225+
super().__init__(progress, total=total, description="Running TRAM estimate")
225226
self.log_likelihoods = log_likelihoods_list
226227
self.increments = increments
227228
self.store_convergence_info = store_convergence_info

deeptime/markov/msm/tram/_tram_dataset.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,10 @@ def restrict_to_largest_connected_set(self, connectivity='post_hoc_RE', connecti
288288
Only needed if connectivity="post_hoc_RE" or "BAR_variance". Values greater than 1.0 weaken the connectivity
289289
conditions. For 'post_hoc_RE' this multiplies the number of hypothetically observed transitions. For
290290
'BAR_variance' this scales the threshold for the minimal allowed variance of free energy differences.
291-
progress : object, default=None
292-
Progress bar that TRAMDataset will call to indicate progress to the user.
293-
Tested for a tqdm progress bar. Should implement `update()` and `close()`.
291+
progress : object
292+
Progress bar object that `TRAMDataset` will call to indicate progress to the user.
293+
Tested for a tqdm progress bar. The interface is checked
294+
via :meth:`supports_progress_interface <deeptime.util.callbacks.supports_progress_interface>`.
294295
295296
Raises
296297
------
@@ -416,8 +417,8 @@ def _find_largest_connected_set(self, connectivity, connectivity_factor, progres
416417
else:
417418
connectivity_fn = tram.find_state_transitions_BAR_variance
418419

419-
with callbacks.Callback(progress, self.n_therm_states * self.n_markov_states,
420-
"Finding connected sets") as callback:
420+
with callbacks.ProgressCallback(progress, "Finding connected sets",
421+
self.n_therm_states * self.n_markov_states) as callback:
421422
(i_s, j_s) = connectivity_fn(self.ttrajs, self.dtrajs, self.bias_matrices, all_state_counts,
422423
self.n_therm_states, self.n_markov_states, connectivity_factor,
423424
callback)

deeptime/src/include/deeptime/common.h

+8-7
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ struct ComputeIndex {
4646
static constexpr auto compute(const Arr &strides, const std::tuple<Ix...> &tup, std::index_sequence<I...>) {
4747
return (0 + ... + (strides[I] * std::get<I>(tup)));
4848
}
49+
50+
template<typename Arr, typename Arr2, std::size_t... I>
51+
static constexpr auto computeContainer(const Arr &strides, const Arr2 &tup, std::index_sequence<I...>) {
52+
return (0 + ... + (strides[I] * std::get<I>(tup)));
53+
}
4954
};
5055

5156
}
@@ -152,13 +157,9 @@ class Index {
152157
* @param indices the Dims-dimensional index
153158
* @return the 1D index
154159
*/
155-
template<typename Arr>
156-
value_type index(const Arr &indices) const {
157-
std::size_t result{0};
158-
for (std::size_t i = 0; i < Dims; ++i) {
159-
result += _cum_size[i] * indices[i];
160-
}
161-
return result;
160+
template<typename Arr, typename Indices = std::make_index_sequence<Dims>>
161+
constexpr value_type index(const Arr &indices) const {
162+
return detail::ComputeIndex<>::computeContainer(_cum_size, indices, Indices{});
162163
}
163164

164165
/**

deeptime/util/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,18 @@
5151
parallel.handle_n_jobs
5252
decorators.cached_property
5353
decorators.plotting_function
54+
55+
callbacks.supports_progress_interface
56+
callbacks.ProgressCallback
57+
58+
platform.module_available
59+
platform.handle_progress_bar
5460
"""
5561

5662
from .stats import QuantityStatistics, confidence_interval
5763
from ._validation import LaggedModelValidator
5864

5965
from . import data
6066
from . import types
67+
from . import callbacks
68+
from . import platform

deeptime/util/callbacks.py

+49-19
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,58 @@
1-
import copy
21
from .platform import handle_progress_bar
32

43

5-
class Callback:
6-
"""Base callback function for the c++ bindings to indicate progress by incrementing a progress bar.
4+
def supports_progress_interface(bar):
5+
r""" Method to check if a progress bar supports the deeptime interface, meaning that it
6+
has callable `update`, `close`, and `set_description` methods. Only these methods are checked; a `total` attribute is not verified.
77
8-
Parameters
9-
----------
10-
progress_bar : object
11-
Tested for a tqdm progress bar. Should implement update() and close() and have .total and .desc properties.
12-
n_iter : int
13-
Number of iterations to completion.
14-
display_text : string
15-
text to display in front of the progress bar.
16-
"""
8+
Parameters
9+
----------
10+
bar : object, optional
11+
The progress bar implementation to check, can be None.
1712
18-
def __init__(self, progress, n_iter=None, display_text=None):
19-
self.progress_bar = handle_progress_bar(progress)()
20-
if display_text is not None:
21-
self.progress_bar.desc = display_text
22-
if n_iter is not None:
23-
self.progress_bar.total = n_iter
13+
Returns
14+
-------
15+
supports : bool
16+
Whether the progress bar is supported.
2417
25-
def __call__(self):
18+
See Also
19+
--------
20+
ProgressCallback
21+
"""
22+
has_methods = all(callable(getattr(bar, method, None)) for method in supports_progress_interface.required_methods)
23+
return has_methods
24+
25+
26+
supports_progress_interface.required_methods = ['update', 'close', 'set_description']
27+
28+
29+
class ProgressCallback:
30+
r"""Base callback function for the c++ bindings to indicate progress by incrementing a progress bar.
31+
32+
Parameters
33+
----------
34+
progress : object
35+
Tested for a tqdm progress bar. Should implement `update()`, `set_description()`, and `close()`. Should
36+
also possess a `total` constructor keyword argument.
37+
total : int
38+
Number of iterations to completion.
39+
description : string
40+
text to display in front of the progress bar.
41+
42+
See Also
43+
--------
44+
supports_progress_interface
45+
"""
46+
47+
def __init__(self, progress, description=None, total=None):
48+
self.progress_bar = handle_progress_bar(progress)(total=total)
49+
assert supports_progress_interface(self.progress_bar), \
50+
f"Progress bar did not satisfy interface! It should at least have " \
51+
f"the method(s) {supports_progress_interface.required_methods}."
52+
if description is not None:
53+
self.progress_bar.set_description(description)
54+
55+
def __call__(self, *args, **kw):
2656
self.progress_bar.update()
2757

2858
def __enter__(self):

deeptime/util/platform.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,17 @@ def handle_progress_bar(progress):
3636
class progress:
3737
def __init__(self, x=None, **_):
3838
self._x = x
39+
self.total = None
3940

40-
def __enter__(self):
41-
return self
42-
43-
def __exit__(self, exc_type, exc_val, exc_tb):
44-
return False
41+
def __enter__(self): return self
42+
def __exit__(self, exc_type, exc_val, exc_tb): return False
4543

4644
def __iter__(self):
4745
for x in self._x:
4846
yield x
4947

5048
def update(self): pass
51-
5249
def close(self): pass
50+
def set_description(self, *_): pass
5351

5452
return progress

docs/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ memory_profiler
1414
mdshare
1515
nbconvert
1616
jupyter
17+
tqdm

docs/source/index_msm.rst

+11-1
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,24 @@ over the encountered state transitions. This is covered in `transition counting
5454
notebooks/mlmsm
5555
notebooks/pcca
5656
notebooks/tpt
57-
notebooks/tram
5857

5958
Furthermore, deeptime implements :class:`Augmented Markov models <deeptime.markov.msm.AugmentedMSMEstimator>`
6059
:footcite:`olsson2017combining` which can be used when experimental data is available, as well as
6160
:class:`Observable Operator Model MSMs <deeptime.markov.msm.OOMReweightedMSM>` :footcite:`nuske2017markov` which is
6261
an unbiased estimator for the MSM transition matrix that corrects for the effect of starting out of equilibrium,
6362
even when short lag times are used.
6463

64+
.. rubric:: Multiensemble MSMs
65+
66+
Deeptime offers the TRAM method :footcite:`wu2016multiensemble` for estimating multiensemble MSMs. These are collections
67+
of MSMs based on simulations that are governed by biased dynamics (e.g., replica exchange simulations
68+
and umbrella sampling).
69+
70+
.. toctree::
71+
:maxdepth: 1
72+
73+
notebooks/tram
74+
6575
.. rubric:: References
6676

6777
.. footbibliography::

docs/source/notebooks

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(SRC bindings.cpp)
2+
pybind11_add_module(custom_metric ${SRC})
3+
target_link_libraries(custom_metric PUBLIC deeptime::deeptime)
4+
if(OpenMP_FOUND)
5+
target_link_libraries(custom_metric PUBLIC OpenMP::OpenMP_CXX)
6+
endif()

examples/clustering_custom_metric/bindings.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "register_clustering.h"
1+
#include <deeptime/clustering/register_clustering.h>
22

33
struct MaximumMetric {
44

0 commit comments

Comments
 (0)