4
4
from deeptime .numeric import logsumexp
5
5
from deeptime .markov .msm import MarkovStateModelCollection
6
6
7
+ from ._tram_dataset import transition_counts_from_count_models , state_counts_from_count_models
7
8
from ._tram_bindings import tram
8
9
9
10
@@ -51,7 +52,7 @@ def __init__(self, count_models, transition_matrices,
51
52
lagrangian_mult_log ,
52
53
modified_state_counts_log ,
53
54
therm_state_energies = None ,
54
- markov_state_energies = None ,
55
+ markov_state_energies = None
55
56
):
56
57
self .n_therm_states = biased_conf_energies .shape [0 ]
57
58
self .n_markov_states = biased_conf_energies .shape [1 ]
@@ -66,6 +67,9 @@ def __init__(self, count_models, transition_matrices,
66
67
else :
67
68
self ._therm_state_energies = therm_state_energies
68
69
70
+ self ._transition_matrices = transition_matrices
71
+ self ._count_models = count_models
72
+
69
73
self ._msm_collection = self ._construct_msm_collection (
70
74
count_models , transition_matrices )
71
75
@@ -145,8 +149,15 @@ def compute_sample_weights(self, dtrajs, bias_matrices, therm_state=-1):
145
149
146
150
.. math:: \mu(x) = \left( \sum_k R^k_{i(x)} \mathrm{exp}[f^k_{i(x)}-b^k(x)] \right)^{-1}
147
151
"""
148
- return tram .compute_sample_weights (therm_state , dtrajs , bias_matrices , self ._therm_state_energies ,
149
- self ._modified_state_counts_log )
152
+ # flatten input data
153
+ dtraj = np .concatenate (dtrajs )
154
+ bias_matrix = np .concatenate (bias_matrices )
155
+
156
+ sample_weights = self ._compute_sample_weights (dtraj , bias_matrix , therm_state )
157
+
158
+ # return in the original list shape
159
+ traj_start_stops = np .concatenate (([0 ], np .cumsum ([len (traj ) for traj in dtrajs ])))
160
+ return [sample_weights [traj_start_stops [i - 1 ]:traj_start_stops [i ]] for i in range (1 , len (traj_start_stops ))]
150
161
151
162
def compute_observable (self , observable_values , dtrajs , bias_matrices , therm_state = - 1 ):
152
163
r""" Compute an observable value.
@@ -169,11 +180,11 @@ def compute_observable(self, observable_values, dtrajs, bias_matrices, therm_sta
169
180
The index of the thermodynamic state in which the observable needs to be computed. If `therm_state=-1`, the
170
181
observable is computed for the unbiased (reference) state.
171
182
"""
172
- sample_weights = self .compute_sample_weights (dtrajs , bias_matrices , therm_state )
183
+ # flatten input data
184
+ observable_values = np .concatenate (observable_values )
173
185
174
- # flatten both
175
- sample_weights = np .reshape (sample_weights , - 1 )
176
- observable_values = np .reshape (observable_values , - 1 )
186
+ sample_weights = self ._compute_sample_weights (np .concatenate (dtrajs ), np .concatenate (bias_matrices ),
187
+ therm_state )
177
188
178
189
return np .dot (sample_weights , observable_values )
179
190
@@ -200,20 +211,68 @@ def compute_PMF(self, dtrajs, bias_matrices, bin_indices, therm_state=-1):
200
211
computed for the unbiased (reference) state.
201
212
"""
202
213
# TODO: account for variable bin widths
203
- sample_weights = np .reshape (self .compute_sample_weights (dtrajs , bias_matrices , therm_state ), - 1 )
204
- binned_samples = np .reshape (bin_indices , - 1 )
214
+ sample_weights = self ._compute_sample_weights (np .concatenate (dtrajs ), np .concatenate (bias_matrices ),
215
+ therm_state )
216
+
217
+ binned_samples = np .concatenate (bin_indices )
205
218
206
219
n_bins = binned_samples .max () + 1
207
220
pmf = np .zeros (n_bins )
208
221
209
222
for i in range (len (pmf )):
210
223
indices = np .where (binned_samples == i )
211
- pmf [i ] = - np .log (np .sum (sample_weights [indices ]))
224
+ if len (indices [0 ]) > 0 :
225
+ pmf [i ] = - np .log (np .sum (sample_weights [indices ]))
212
226
213
227
# shift minimum to zero
214
228
pmf -= pmf .min ()
215
229
return pmf
216
230
231
def compute_log_likelihood(self, dtrajs, bias_matrices):
    r"""Evaluate the parameter-dependent part of the log-likelihood of the given data.

    See :footcite:`wu2016multiensemble`, Equation (9), for the definition.

    Parameters
    ----------
    dtrajs : list(np.ndarray)
        Discrete trajectories. `dtrajs[i][n]` is the Markov state index of the :math:`n`-th sample of the
        :math:`i`-th trajectory.
    bias_matrices : list(np.ndarray)
        Bias energy matrices, one per trajectory. `bias_matrices[i][n, k]` is the bias energy
        :math:`b^k(x_{i,n})` of the :math:`n`-th sample of the :math:`i`-th trajectory, evaluated at
        thermodynamic state :math:`k`. Each sample thus carries one bias energy per thermodynamic state
        (`n_therm_states` values).

    Returns
    -------
    log_likelihood : float
        The parameter-dependent part of the log-likelihood.

    Notes
    -----
    "Parameter-dependent" means that the factor

    .. math:: \prod_{x \in X} e^{-b^{k(x)}(x)}

    is left out, since it is constant with respect to the parameters. What remains is

    .. math:: \log \prod_{k=1}^K \left(\prod_{i,j} (p_{ij}^k)^{c_{ij}^k}\right) \left(\prod_{i} \prod_{x \in X_i^k} \mu(x) e^{f_i^k} \right)
    """
    # The binding works on one flat array of samples, so the per-trajectory lists are joined first.
    flat_dtraj = np.concatenate(dtrajs)
    flat_bias_matrix = np.concatenate(bias_matrices)

    # Both count arrays are rebuilt from the count models stored on this model.
    model_dims = (self.n_therm_states, self.n_markov_states)
    counts_transitions = transition_counts_from_count_models(*model_dims, self._count_models)
    counts_states = state_counts_from_count_models(*model_dims, self._count_models)

    return tram.compute_log_likelihood(
        flat_dtraj,
        flat_bias_matrix,
        self._biased_conf_energies,
        self._modified_state_counts_log,
        self._therm_state_energies,
        counts_states,
        counts_transitions,
        self._transition_matrices,
    )
275
+
217
276
def _construct_msm_collection (self , count_models , transition_matrices ):
218
277
r""" Construct a MarkovStateModelCollection from the transition matrices and energy estimates.
219
278
For each of the thermodynamic states, one MarkovStateModel is added to the MarkovStateModelCollection. The
@@ -237,3 +296,8 @@ def _construct_msm_collection(self, count_models, transition_matrices):
237
296
return MarkovStateModelCollection (transition_matrices_connected , stationary_distributions ,
238
297
reversible = True , count_models = count_models ,
239
298
transition_matrix_tolerance = 1e-8 )
299
+
300
def _compute_sample_weights(self, dtraj, bias_matrix, therm_state=-1):
    r"""Compute sample weights for input that has already been flattened.

    Parameters
    ----------
    dtraj : np.ndarray
        Markov state indices of all samples; the callers in this class pass the concatenation of all
        discrete trajectories.
    bias_matrix : np.ndarray
        Bias energies of all samples; the callers in this class pass the concatenation of all bias matrices.
    therm_state : int, optional, default=-1
        Index of the thermodynamic state, forwarded unchanged to the binding.

    Returns
    -------
    sample_weights : np.ndarray
        Element-wise exponential of the values returned by `tram.compute_sample_weights_log`
        (presumably the logarithms of the sample weights -- confirm against the binding).
    """
    log_weights = np.asarray(
        tram.compute_sample_weights_log(
            dtraj,
            bias_matrix,
            self._therm_state_energies,
            self._modified_state_counts_log,
            therm_state,
        )
    )
    return np.exp(log_weights)
0 commit comments