@@ -19,6 +19,18 @@ def _determine_n_therm_states(dtrajs, ttrajs):
19
19
return _determine_n_states (ttrajs )
20
20
21
21
22
+ def _split_at_negative_state_indices (trajectory_fragment , negative_state_indices ):
23
+ split_fragments = np .split (trajectory_fragment , negative_state_indices )
24
+ sub_fragments = []
25
+ # now get rid of the negative state indices.
26
+ for frag in split_fragments :
27
+ frag = frag [frag >= 0 ]
28
+ # Only add to the list if there are any samples left in the fragments
29
+ if len (frag ) > 0 :
30
+ sub_fragments .append (frag )
31
+ return sub_fragments
32
+
33
+
22
34
def transition_counts_from_count_models (n_therm_states , n_markov_states , count_models ):
23
35
transition_counts = np .zeros ((n_therm_states , n_markov_states , n_markov_states ), dtype = np .int32 )
24
36
@@ -454,13 +466,24 @@ def _find_trajectory_fragments(self):
454
466
# get a mapping from trajectory segments to thermodynamic states
455
467
fragment_indices = self ._find_trajectory_fragment_mapping ()
456
468
457
- fragments = []
469
+ fragments = [[] for _ in range ( self . n_therm_states ) ]
458
470
# for each them. state k, gather all trajectory fragments that were sampled at that state.
459
471
for k in range (self .n_therm_states ):
460
- # take the fragments based on the list of indices. Exclude all values that are less than zero. They don't
461
- # belong in the connected set.
462
- fragments .append ([self .dtrajs [traj_idx ][start :stop ][self .dtrajs [traj_idx ][start :stop ] >= 0 ]
463
- for (traj_idx , start , stop ) in fragment_indices [k ]])
472
+ # Select the fragments using the list of indices.
473
+ for (traj_idx , start , stop ) in fragment_indices [k ]:
474
+ fragment = self .dtrajs [traj_idx ][start :stop ]
475
+
476
+ # Whenever state values are negative, those samples do not belong in the connected set and need to be
477
+ # excluded. We split trajectories where negative state indices occur.
478
+ # Example: [0, 0, 2, -1, 2, 1, 0], we want to exclude the sample with state index -1.
479
+ # Simply filtering out negative state indices would lead to [0, 0, 2, 2, 1, 0] which gives a transition
480
+ # 2 -> 2 which doesn't exist. Instead, split the trajectory at negative state indices to get
481
+ # [0, 0, 2], [2, 1, 0]
482
+ negative_state_indices = np .where (fragment < 0 )[0 ]
483
+ if len (negative_state_indices ) > 0 :
484
+ fragments [k ].extend (_split_at_negative_state_indices (fragment , negative_state_indices ))
485
+ else :
486
+ fragments [k ].append (fragment )
464
487
return fragments
465
488
466
489
def _find_trajectory_fragment_mapping (self ):
0 commit comments