Skip to content

Commit 5a785b2

Browse files
authored
TRAM fix artificial transition counts by negative state indices (#194)
1 parent 0b76177 commit 5a785b2

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

deeptime/markov/msm/tram/_tram_dataset.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,18 @@ def _determine_n_therm_states(dtrajs, ttrajs):
1919
return _determine_n_states(ttrajs)
2020

2121

22+
def _split_at_negative_state_indices(trajectory_fragment, negative_state_indices):
23+
split_fragments = np.split(trajectory_fragment, negative_state_indices)
24+
sub_fragments = []
25+
# now get rid of the negative state indices.
26+
for frag in split_fragments:
27+
frag = frag[frag >= 0]
28+
# Only add to the list if there are any samples left in the fragments
29+
if len(frag) > 0:
30+
sub_fragments.append(frag)
31+
return sub_fragments
32+
33+
2234
def transition_counts_from_count_models(n_therm_states, n_markov_states, count_models):
2335
transition_counts = np.zeros((n_therm_states, n_markov_states, n_markov_states), dtype=np.int32)
2436

@@ -454,13 +466,24 @@ def _find_trajectory_fragments(self):
454466
# get a mapping from trajectory segments to thermodynamic states
455467
fragment_indices = self._find_trajectory_fragment_mapping()
456468

457-
fragments = []
469+
fragments = [[] for _ in range(self.n_therm_states)]
458470
# for each them. state k, gather all trajectory fragments that were sampled at that state.
459471
for k in range(self.n_therm_states):
460-
# take the fragments based on the list of indices. Exclude all values that are less than zero. They don't
461-
# belong in the connected set.
462-
fragments.append([self.dtrajs[traj_idx][start:stop][self.dtrajs[traj_idx][start:stop] >= 0]
463-
for (traj_idx, start, stop) in fragment_indices[k]])
472+
# Select the fragments using the list of indices.
473+
for (traj_idx, start, stop) in fragment_indices[k]:
474+
fragment = self.dtrajs[traj_idx][start:stop]
475+
476+
# Whenever state values are negative, those samples do not belong in the connected set and need to be
477+
# excluded. We split trajectories where negative state indices occur.
478+
# Example: [0, 0, 2, -1, 2, 1, 0], we want to exclude the sample with state index -1.
479+
# Simply filtering out negative state indices would lead to [0, 0, 2, 2, 1, 0] which gives a transition
480+
# 2 -> 2 which doesn't exist. Instead, split the trajectory at negative state indices to get
481+
# [0, 0, 2], [2, 1, 0]
482+
negative_state_indices = np.where(fragment < 0)[0]
483+
if len(negative_state_indices) > 0:
484+
fragments[k].extend(_split_at_negative_state_indices(fragment, negative_state_indices))
485+
else:
486+
fragments[k].append(fragment)
464487
return fragments
465488

466489
def _find_trajectory_fragment_mapping(self):

tests/markov/msm/test_tram_datatset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,10 @@ def test_get_trajectory_fragments(dtrajs, ttrajs):
227227
bias_matrices = make_matching_bias_matrix(dtrajs)
228228
dataset = TRAMDataset(dtrajs=dtrajs, ttrajs=ttrajs, bias_matrices=bias_matrices)
229229

230-
# dtraj should be split into fragments [[[1, 3], [5, 6], [8, 9]], [[10, 11], [12, 13]]] due to replica exchanges
230+
# dtraj should be split into fragments [[[1], [3], [5, 6], [8, 9]], [[10, 11], [12, 13]]] due to replica exchanges
231231
# found in ttrajs. This should lead having only 5 transitions in transition counts:
232232
np.testing.assert_equal(dataset.state_counts.sum(), 10)
233-
np.testing.assert_equal(dataset.transition_counts.sum(), 5)
233+
np.testing.assert_equal(dataset.transition_counts.sum(), 4)
234234

235235

236236
def test_unknown_connectivity():

0 commit comments

Comments
 (0)