Skip to content

Commit de8768d

Browse files
authored
fix tram bug (#209)
1 parent dc320db commit de8768d

File tree

5 files changed

+19
-5
lines changed

5 files changed

+19
-5
lines changed

deeptime/markov/msm/tram/_tram.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def _make_tram_estimator(self, model, dataset):
238238
# copy free energies along the markoc state axis to get initial biased_conf_energies
239239
biased_conf_energies = np.repeat(free_energies[:, None], dataset.n_markov_states, axis=1)
240240
else:
241-
biased_conf_energies = np.zeros((dataset.n_markov_states, dataset.n_therm_states))
241+
biased_conf_energies = np.zeros((dataset.n_therm_states, dataset.n_markov_states))
242242

243243
lagrangian_mult_log = tram.initialize_lagrangians(dataset.transition_counts)
244244
modified_state_counts = np.zeros_like(lagrangian_mult_log) # intialize this as the dataset state counts???

deeptime/src/include/deeptime/markov/msm/tram/mbar.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ initialize_MBAR(BiasMatrix <dtype> biasMatrix, CountsMatrix stateCounts, std::si
6363
double maxErr = 1e-6, std::size_t callbackInterval = 1, const py::object *callback = nullptr) {
6464
// get dimensions...
6565
auto nThermStates = stateCounts.shape(0);
66-
auto nSamples = biasMatrix.shape(1);
66+
auto nSamples = biasMatrix.shape(0);
6767

6868
// work in log space so compute the log of the statecounts beforehand
6969
auto stateCountsLog = std::vector<dtype>(nThermStates);

devtools/conda-setup+build.yml

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ steps:
55
conda config --set quiet true
66
displayName: Configure conda
77
- bash: |
8+
conda clean --all
89
conda install mamba
910
mamba update --all
1011
mamba install boa conda-build conda-verify pip

examples/methods/plot_tram.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def sample_trajectories(bias_functions):
7676
dtrajs = clustering.transform(trajectories.flatten()).reshape((len(bias_matrices), n_samples))
7777

7878
from tqdm import tqdm
79-
tram = TRAM(lagtime=1, maxiter=1000, maxerr=1e-3, progress=tqdm)
79+
tram = TRAM(lagtime=1, maxiter=1000, maxerr=1e-3, progress=tqdm, init_strategy="MBAR")
8080

8181
# For every simulation frame seen in trajectory i and time step t, btrajs[i][t,k] is the
8282
# bias energy of that frame evaluated in the k'th thermodynamic state (i.e. at the k'th

tests/markov/msm/test_tram.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,24 @@ def test_tqdm_progress_bar():
265265
tram.fit(make_random_input_data(5, 5))
266266

267267

268-
def test_fit_with_dataset():
268+
@pytest.mark.parametrize(
269+
"init_strategy", ["MBAR", None]
270+
)
271+
def test_fit_with_dataset(init_strategy):
269272
dataset = TRAMDataset(dtrajs=[np.asarray([0, 1, 2])], bias_matrices=[np.asarray([[1.], [2.], [3.]])])
270-
tram = TRAM()
273+
tram = TRAM(init_strategy=init_strategy)
271274
tram.fit(dataset)
272275

273276

277+
@pytest.mark.parametrize(
278+
"init_strategy", ["MBAR", None]
279+
)
280+
def test_fit_with_dataset(init_strategy):
281+
input_data = make_random_input_data(20, 2)
282+
tram = TRAM(init_strategy=init_strategy)
283+
tram.fit(input_data)
284+
285+
274286
def test_mbar_initalization():
275287
(dtrajs, bias_matrices) = make_random_input_data(5, 5, make_ttrajs=False)
276288
tram = TRAM(callback_interval=2, maxiter=0, progress=tqdm, init_maxiter=100)
@@ -296,3 +308,4 @@ def test_mbar_initialization_zero_iterations():
296308
model1 = tram1.fit_fetch(input_data)
297309
model2 = tram2.fit_fetch(input_data)
298310
np.testing.assert_equal(model1.biased_conf_energies, model2.biased_conf_energies)
311+

0 commit comments

Comments
 (0)