Skip to content

Commit 0b76177

Browse files
authored
Small TRAM refactorings (#192)
1 parent c413dca commit 0b76177

File tree

14 files changed

+444
-437
lines changed

14 files changed

+444
-437
lines changed

deeptime/markov/msm/tram/_bindings/src/tram_module.cpp

+13-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include "deeptime/markov/msm/tram/tram.h"
33
#include "deeptime/markov/msm/tram/connected_set.h"
44
#include "deeptime/markov/msm/tram/trajectory_mapping.h"
5-
5+
66
PYBIND11_MODULE(_tram_bindings, m) {
77
using namespace pybind11::literals;
88
using namespace deeptime::markov::tram;
@@ -12,8 +12,8 @@ PYBIND11_MODULE(_tram_bindings, m) {
1212
py::class_<TRAM<double>>(tramMod, "TRAM")
1313
.def(py::init<std::size_t, std::size_t>(), "n_therm_states"_a, "n_markov_states"_a)
1414
.def(py::init<deeptime::np_array_nfc<double> &,
15-
deeptime::np_array_nfc<double> &, deeptime::np_array_nfc<double> &>(),
16-
"biased_conf_energies"_a, "lagrangian_mult_log"_a, "modified_state_counts_log"_a)
15+
deeptime::np_array_nfc<double> &, deeptime::np_array_nfc<double> &>(),
16+
"biased_conf_energies"_a, "lagrangian_mult_log"_a, "modified_state_counts_log"_a)
1717
.def("estimate", &TRAM<double>::estimate,
1818
"input"_a, "max_iter"_a = 1000, "max_err"_a = 1e-8, "callback_interval"_a = 1,
1919
"track_log_likelihoods"_a = false, "callback"_a = nullptr)
@@ -22,17 +22,15 @@ PYBIND11_MODULE(_tram_bindings, m) {
2222
.def_property_readonly("modified_state_counts_log", &TRAM<double>::modifiedStateCountsLog)
2323
.def_property_readonly("lagrangian_mult_log", &TRAM<double>::lagrangianMultLog)
2424
.def_property_readonly("therm_state_energies", &TRAM<double>::thermStateEnergies)
25-
.def_property_readonly("markov_state_energies", &TRAM<double>::markovStateEnergies)
26-
.def("compute_log_likelihood", &TRAM<double>::computeLogLikelihood, py::call_guard<py::gil_scoped_release>());
27-
25+
.def_property_readonly("markov_state_energies", &TRAM<double>::markovStateEnergies);
2826

2927
py::class_<TRAMInput<double>, std::shared_ptr<TRAMInput<double>>>(tramMod, "TRAMInput").def(
30-
py::init<deeptime::np_array_nfc<int> &&, deeptime::np_array_nfc<int> &&, DTrajs, BiasMatrices<double>>(),
31-
"state_counts"_a, "transition_counts"_a, "dtrajs"_a, "bias_matrices"_a);
28+
py::init<deeptime::np_array_nfc<int> &&, deeptime::np_array_nfc<int> &&, DTraj, BiasMatrix<double>>(),
29+
"state_counts"_a, "transition_counts"_a, "dtraj"_a, "bias_matrix"_a);
3230

33-
tramMod.def("compute_sample_weights", &computeSampleWeights<double>, py::call_guard<py::gil_scoped_release>(),
34-
"therm_state_index"_a = -1, "dtrajs"_a, "bias_matrices"_a, "therm_state_energies"_a,
35-
"modified_state_counts_log"_a);
31+
tramMod.def("compute_sample_weights_log", &computeSampleWeightsLog<double>, py::call_guard<py::gil_scoped_release>(),
32+
"dtraj"_a, "bias_matrix"_a, "therm_state_energies"_a,
33+
"modified_state_counts_log"_a, "therm_state_index"_a = -1);
3634

3735
tramMod.def("find_state_transitions_post_hoc_RE",
3836
&findStateTransitions<double, OverlapPostHocReplicaExchange<double>>,
@@ -46,5 +44,9 @@ PYBIND11_MODULE(_tram_bindings, m) {
4644
"connectivity_factor"_a, "callback"_a);
4745

4846
tramMod.def("find_trajectory_fragment_indices", &findTrajectoryFragmentIndices, "ttrajs"_a, "n_therm_states"_a);
47+
48+
tramMod.def("compute_log_likelihood", computeLogLikelihood<double>, py::call_guard<py::gil_scoped_release>(),
49+
"dtraj"_a, "biasMatrix"_a, "biasedConfEnergies"_a, "modifiedStateCountsLog"_a,
50+
"thermStateEnergies"_a, "stateCounts"_a, "transitionCounts"_a, "transitionMatrices"_a);
4951
}
5052
}

deeptime/markov/msm/tram/_tram.py

-25
Original file line numberDiff line numberDiff line change
@@ -119,31 +119,6 @@ def __init__(
119119
self.log_likelihoods = []
120120
self.increments = []
121121

122-
@property
123-
def compute_log_likelihood(self) -> Optional[float]:
124-
r"""The parameter-dependent part of the TRAM likelihood.
125-
126-
The definition can be found in :footcite:`wu2016multiensemble`, Equation (9).
127-
128-
Returns
129-
-------
130-
log_likelihood : float
131-
The parameter-dependent part of the log-likelihood.
132-
133-
134-
Notes
135-
-----
136-
Parameter-dependent, i.e., the factor
137-
138-
.. math:: \prod_{x \in X} e^{-b^{k(x)}(x)}
139-
140-
does not occur in the log-likelihood as it is constant with respect to the parameters, leading to
141-
142-
.. math:: \log \prod_{k=1}^K \left(\prod_{i,j} (p_{ij}^k)^{c_{ij}^k}\right) \left(\prod_{i} \prod_{x \in X_i^k} \mu(x) e^{f_i^k} \right)
143-
"""
144-
if self._tram_estimator is not None:
145-
return self._tram_estimator.compute_log_likelihood()
146-
147122
def fetch_model(self) -> Optional[TRAMModel]:
148123
r"""Yields the most recent :class:`MarkovStateModelCollection` that was estimated.
149124
Can be None if fit was not called.

deeptime/markov/msm/tram/_tram_dataset.py

+34-25
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
from deeptime.util import types, callbacks
66
from deeptime.util.decorators import cached_property
7-
87
from deeptime.markov import TransitionCountEstimator, TransitionCountModel
9-
108
from ._tram_bindings import tram
119

1210

@@ -21,6 +19,30 @@ def _determine_n_therm_states(dtrajs, ttrajs):
2119
return _determine_n_states(ttrajs)
2220

2321

22+
def transition_counts_from_count_models(n_therm_states, n_markov_states, count_models):
23+
transition_counts = np.zeros((n_therm_states, n_markov_states, n_markov_states), dtype=np.int32)
24+
25+
for k in range(n_therm_states):
26+
model_k = count_models[k]
27+
if model_k.count_matrix.sum() > 0:
28+
i_s, j_s = np.meshgrid(model_k.state_symbols, model_k.state_symbols, indexing='ij')
29+
# place submodel counts in our full-sized count matrices
30+
transition_counts[k, i_s, j_s] = model_k.count_matrix
31+
32+
return transition_counts
33+
34+
35+
def state_counts_from_count_models(n_therm_states, n_markov_states, count_models):
36+
state_counts = np.zeros((n_therm_states, n_markov_states), dtype=np.int32)
37+
38+
for k in range(n_therm_states):
39+
model_k = count_models[k]
40+
if model_k.count_matrix.sum() > 0:
41+
state_counts[k, model_k.state_symbols] = model_k.state_histogram
42+
43+
return state_counts
44+
45+
2446
def to_zero_padded_array(arrays, desired_shape):
2547
"""Pad a list of numpy arrays with zeros to desired shape. Desired shape should be at least the size of the
2648
largest np array in the list.
@@ -55,8 +77,8 @@ def _invalidate_caches():
5577
class TRAMDataset:
5678
r""" Dataset for organizing data and obtaining properties from data that are needed for TRAM.
5779
The minimum required parameters for constructing a TRAMDataset are the `dtrajs` and `bias_matrices`. In this case,
58-
`ttrajs` are inferred from the shape of the `dtrajs`, by assuming each trajectory in `dtrajs` corresponds to a unique
59-
thermodynamic state, with the index corresponding to the index of occurrence in `dtrajs`.
80+
`ttrajs` are inferred from the shape of the `dtrajs`, by assuming each trajectory in `dtrajs` corresponds to a
81+
unique thermodynamic state, with the index corresponding to the index of occurrence in `dtrajs`.
6082
6183
The values at identical indices in `dtrajs`, `ttrajs` and `bias_matrices` correspond to the same sample. For example, at
6284
indices `(i, n)` we find information about the :math:`n`-th sample in trajectory :math:`i`. `dtrajs[i][n]` gives us
@@ -141,8 +163,11 @@ def __init__(self, dtrajs, bias_matrices, ttrajs=None, n_therm_states=None, n_ma
141163

142164
@property
143165
def tram_input(self):
144-
r""" The TRAMInput object containing the data needed for estimation. """
145-
return tram.TRAMInput(self.state_counts, self.transition_counts, self.dtrajs, self.bias_matrices)
166+
r""" The TRAMInput object containing the data needed for estimation.
167+
For estimation purposes, it does not matter which thermodynamic state each sample was sampled at. The dtrajs and
168+
bias_matrices are therefore flattened along the first dimension, to speed up estimation. """
169+
return tram.TRAMInput(self.state_counts, self.transition_counts,
170+
np.concatenate(self.dtrajs), np.concatenate(self.bias_matrices))
146171

147172
@property
148173
def n_therm_states(self):
@@ -169,16 +194,7 @@ def transition_counts(self):
169194
:getter: the transition counts
170195
:type: ndarray(n, m, m)
171196
"""
172-
transition_counts = np.zeros((self.n_therm_states, self.n_markov_states, self.n_markov_states), dtype=np.int32)
173-
174-
for k in range(self.n_therm_states):
175-
model_k = self.count_models[k]
176-
if model_k.count_matrix.sum() > 0:
177-
i_s, j_s = np.meshgrid(model_k.state_symbols, model_k.state_symbols)
178-
# place submodel counts in our full-sized count matrices
179-
transition_counts[k, i_s, j_s] = model_k.count_matrix.T
180-
181-
return transition_counts
197+
return transition_counts_from_count_models(self.n_therm_states, self.n_markov_states, self.count_models)
182198

183199
@cached_property
184200
def state_counts(self):
@@ -192,14 +208,7 @@ def state_counts(self):
192208
matrices that are all the same shape, which is easier to handle (matrices are padded with zeros for all empty
193209
states that got dropped by the TransitionCountModels).
194210
"""
195-
state_counts = np.zeros((self.n_therm_states, self.n_markov_states), dtype=np.int32)
196-
197-
for k in range(self.n_therm_states):
198-
model_k = self.count_models[k]
199-
if model_k.count_matrix.sum() > 0:
200-
state_counts[k, model_k.state_symbols] = model_k.state_histogram
201-
202-
return state_counts
211+
return state_counts_from_count_models(self.n_therm_states, self.n_markov_states, self.count_models)
203212

204213
def check_against_model(self, model):
205214
r""" Check the number of thermodynamic states of the model against that of the dataset. The number of
@@ -385,7 +394,7 @@ def _find_largest_connected_set(self, connectivity, connectivity_factor, progres
385394
all_state_counts = np.asarray([estimator.fit_fetch(dtraj).state_histogram for dtraj in self.dtrajs],
386395
dtype=object)
387396
# pad with zeros so they are all the same size and easier for the cpp module to handle
388-
all_state_counts = to_zero_padded_array(all_state_counts, self.n_markov_states)
397+
all_state_counts = to_zero_padded_array(all_state_counts, self.n_markov_states).astype(np.int32)
389398

390399
# get list of all possible transitions between thermodynamic states. A transition is only possible when two
391400
# thermodynamic states have an overlapping markov state. Whether the markov state overlaps depends on the

deeptime/markov/msm/tram/_tram_model.py

+74-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from deeptime.numeric import logsumexp
55
from deeptime.markov.msm import MarkovStateModelCollection
66

7+
from ._tram_dataset import transition_counts_from_count_models, state_counts_from_count_models
78
from ._tram_bindings import tram
89

910

@@ -51,7 +52,7 @@ def __init__(self, count_models, transition_matrices,
5152
lagrangian_mult_log,
5253
modified_state_counts_log,
5354
therm_state_energies=None,
54-
markov_state_energies=None,
55+
markov_state_energies=None
5556
):
5657
self.n_therm_states = biased_conf_energies.shape[0]
5758
self.n_markov_states = biased_conf_energies.shape[1]
@@ -66,6 +67,9 @@ def __init__(self, count_models, transition_matrices,
6667
else:
6768
self._therm_state_energies = therm_state_energies
6869

70+
self._transition_matrices = transition_matrices
71+
self._count_models = count_models
72+
6973
self._msm_collection = self._construct_msm_collection(
7074
count_models, transition_matrices)
7175

@@ -145,8 +149,15 @@ def compute_sample_weights(self, dtrajs, bias_matrices, therm_state=-1):
145149
146150
.. math:: \mu(x) = \left( \sum_k R^k_{i(x)} \mathrm{exp}[f^k_{i(k)}-b^k(x)] \right)^{-1}
147151
"""
148-
return tram.compute_sample_weights(therm_state, dtrajs, bias_matrices, self._therm_state_energies,
149-
self._modified_state_counts_log)
152+
# flatten input data
153+
dtraj = np.concatenate(dtrajs)
154+
bias_matrix = np.concatenate(bias_matrices)
155+
156+
sample_weights = self._compute_sample_weights(dtraj, bias_matrix, therm_state)
157+
158+
# return in the original list shape
159+
traj_start_stops = np.concatenate(([0], np.cumsum([len(traj) for traj in dtrajs])))
160+
return [sample_weights[traj_start_stops[i - 1]:traj_start_stops[i]] for i in range(1, len(traj_start_stops))]
150161

151162
def compute_observable(self, observable_values, dtrajs, bias_matrices, therm_state=-1):
152163
r""" Compute an observable value.
@@ -169,11 +180,11 @@ def compute_observable(self, observable_values, dtrajs, bias_matrices, therm_sta
169180
The index of the thermodynamic state in which the observable needs to be computed. If `therm_state=-1`, the
170181
observable is computed for the unbiased (reference) state.
171182
"""
172-
sample_weights = self.compute_sample_weights(dtrajs, bias_matrices, therm_state)
183+
# flatten input data
184+
observable_values = np.concatenate(observable_values)
173185

174-
# flatten both
175-
sample_weights = np.reshape(sample_weights, -1)
176-
observable_values = np.reshape(observable_values, -1)
186+
sample_weights = self._compute_sample_weights(np.concatenate(dtrajs), np.concatenate(bias_matrices),
187+
therm_state)
177188

178189
return np.dot(sample_weights, observable_values)
179190

@@ -200,20 +211,68 @@ def compute_PMF(self, dtrajs, bias_matrices, bin_indices, therm_state=-1):
200211
computed for the unbiased (reference) state.
201212
"""
202213
# TODO: account for variable bin widths
203-
sample_weights = np.reshape(self.compute_sample_weights(dtrajs, bias_matrices, therm_state), -1)
204-
binned_samples = np.reshape(bin_indices, -1)
214+
sample_weights = self._compute_sample_weights(np.concatenate(dtrajs), np.concatenate(bias_matrices),
215+
therm_state)
216+
217+
binned_samples = np.concatenate(bin_indices)
205218

206219
n_bins = binned_samples.max() + 1
207220
pmf = np.zeros(n_bins)
208221

209222
for i in range(len(pmf)):
210223
indices = np.where(binned_samples == i)
211-
pmf[i] = -np.log(np.sum(sample_weights[indices]))
224+
if len(indices[0]) > 0:
225+
pmf[i] = -np.log(np.sum(sample_weights[indices]))
212226

213227
# shift minimum to zero
214228
pmf -= pmf.min()
215229
return pmf
216230

231+
def compute_log_likelihood(self, dtrajs, bias_matrices):
232+
r"""The (parameter-dependent part of the) log-likelihood of observing the given data.
233+
234+
The definition can be found in :footcite:`wu2016multiensemble`, Equation (9).
235+
236+
Parameters
237+
----------
238+
dtrajs : list(np.ndarray)
239+
The list of discrete trajectories. `dtrajs[i][n]` contains the Markov state index of the :math:`n`-th sample
240+
in the :math:`i`-th trajectory.
241+
bias_matrices : list(np.ndarray)
242+
The bias energy matrices. `bias_matrices[i][n, k]` contains the bias energy of the :math:`n`-th sample from
243+
the :math:`i`-th trajectory, evaluated at thermodynamic state :math:`k`, :math:`b^k(x_{i,n})`. The bias
244+
energy matrices should have the same size as `dtrajs` in both the first and second dimension. The third
245+
dimension is of size `n_therm_states`, i.e. for each sample, the bias energy in every thermodynamic state is
246+
calculated and stored in the `bias_matrices`.
247+
248+
Returns
249+
-------
250+
log_likelihood : float
251+
The parameter-dependent part of the log-likelihood.
252+
253+
254+
Notes
255+
-----
256+
Parameter-dependent, i.e., the factor
257+
258+
.. math:: \prod_{x \in X} e^{-b^{k(x)}(x)}
259+
260+
does not occur in the log-likelihood as it is constant with respect to the parameters, leading to
261+
262+
.. math:: \log \prod_{k=1}^K \left(\prod_{i,j} (p_{ij}^k)^{c_{ij}^k}\right) \left(\prod_{i} \prod_{x \in X_i^k} \mu(x) e^{f_i^k} \right)
263+
"""
264+
dtraj = np.concatenate(dtrajs)
265+
bias_matrix = np.concatenate(bias_matrices)
266+
267+
transition_counts = transition_counts_from_count_models(self.n_therm_states, self.n_markov_states,
268+
self._count_models)
269+
270+
state_counts = state_counts_from_count_models(self.n_therm_states, self.n_markov_states, self._count_models)
271+
272+
return tram.compute_log_likelihood(dtraj, bias_matrix, self._biased_conf_energies,
273+
self._modified_state_counts_log, self._therm_state_energies, state_counts,
274+
transition_counts, self._transition_matrices)
275+
217276
def _construct_msm_collection(self, count_models, transition_matrices):
218277
r""" Construct a MarkovStateModelCollection from the transition matrices and energy estimates.
219278
For each of the thermodynamic states, one MarkovStateModel is added to the MarkovStateModelCollection. The
@@ -237,3 +296,8 @@ def _construct_msm_collection(self, count_models, transition_matrices):
237296
return MarkovStateModelCollection(transition_matrices_connected, stationary_distributions,
238297
reversible=True, count_models=count_models,
239298
transition_matrix_tolerance=1e-8)
299+
300+
def _compute_sample_weights(self, dtraj, bias_matrix, therm_state=-1):
301+
sample_weights = tram.compute_sample_weights_log(dtraj, bias_matrix, self._therm_state_energies,
302+
self._modified_state_counts_log, therm_state)
303+
return np.exp(np.asarray(sample_weights))

0 commit comments

Comments
 (0)