@@ -265,12 +265,24 @@ def test_tqdm_progress_bar():
265
265
tram .fit (make_random_input_data (5 , 5 ))
266
266
267
267
268
- def test_fit_with_dataset ():
268
+ @pytest .mark .parametrize (
269
+ "init_strategy" , ["MBAR" , None ]
270
+ )
271
+ def test_fit_with_dataset (init_strategy ):
269
272
dataset = TRAMDataset (dtrajs = [np .asarray ([0 , 1 , 2 ])], bias_matrices = [np .asarray ([[1. ], [2. ], [3. ]])])
270
- tram = TRAM ()
273
+ tram = TRAM (init_strategy = init_strategy )
271
274
tram .fit (dataset )
272
275
273
276
277
+ @pytest .mark .parametrize (
278
+ "init_strategy" , ["MBAR" , None ]
279
+ )
280
+ def test_fit_with_dataset (init_strategy ):
281
+ input_data = make_random_input_data (20 , 2 )
282
+ tram = TRAM (init_strategy = init_strategy )
283
+ tram .fit (input_data )
284
+
285
+
274
286
def test_mbar_initalization ():
275
287
(dtrajs , bias_matrices ) = make_random_input_data (5 , 5 , make_ttrajs = False )
276
288
tram = TRAM (callback_interval = 2 , maxiter = 0 , progress = tqdm , init_maxiter = 100 )
@@ -296,3 +308,4 @@ def test_mbar_initialization_zero_iterations():
296
308
model1 = tram1 .fit_fetch (input_data )
297
309
model2 = tram2 .fit_fetch (input_data )
298
310
np .testing .assert_equal (model1 .biased_conf_energies , model2 .biased_conf_energies )
311
+
0 commit comments