33
44from __future__ import annotations
55
6- from typing import Tuple
6+ from typing import Tuple , Union
77
88import pytest
99import torch
@@ -127,7 +127,7 @@ def simulator(theta):
127127 model = model , num_transforms = 2 , dtype = torch .float32
128128 )
129129 )
130- train_kwargs = dict (force_first_round_loss = True )
130+ train_kwargs = dict ()
131131 elif method == NLE :
132132 kwargs = dict (
133133 density_estimator = likelihood_nn (
@@ -152,9 +152,12 @@ def simulator(theta):
152152 x = simulator (theta ).to (data_device )
153153 theta = theta .to (data_device )
154154
155- estimator = inferer . append_simulations ( theta , x , data_device = data_device ). train (
156- training_batch_size = 100 , max_num_epochs = max_num_epochs , ** train_kwargs
155+ data_kwargs = (
156+ dict ( proposal = proposals [ - 1 ]) if method in [ NPE_A , NPE_C ] else dict ()
157157 )
158+ estimator = inferer .append_simulations (
159+ theta , x , data_device = data_device , ** data_kwargs
160+ ).train (max_num_epochs = max_num_epochs , ** train_kwargs )
158161
159162 # mcmc cases
160163 if sampling_method in ["slice_np" , "slice_np_vectorized" , "nuts_pymc" ]:
@@ -436,3 +439,29 @@ def test_boxuniform_device_handling(arg_device, device):
436439 low = zeros (1 ).to (arg_device ), high = ones (1 ).to (arg_device ), device = device
437440 )
438441 NPE_C (prior = prior , device = arg_device )
442+
443+
444+ @pytest .mark .gpu
445+ @pytest .mark .parametrize ("method" , [NPE_A , NPE_C ])
446+ @pytest .mark .parametrize ("device" , ["cpu" , "gpu" ])
447+ def test_multiround_mdn_training_on_device (method : Union [NPE_A , NPE_C ], device : str ):
448+ num_dim = 2
449+ num_rounds = 2
450+ num_simulations = 100
451+ device = process_device ("gpu" )
452+ prior = BoxUniform (- torch .ones (num_dim ), torch .ones (num_dim ), device = device )
453+ simulator = diagonal_linear_gaussian
454+
455+ estimator = "mdn_snpe_a" if method == NPE_A else "mdn"
456+
457+ trainer = method (prior , density_estimator = estimator , device = device )
458+
459+ theta = prior .sample ((num_simulations ,))
460+ x = simulator (theta )
461+
462+ proposal = prior
463+ for _ in range (num_rounds ):
464+ trainer .append_simulations (theta , x , proposal = proposal ).train (max_num_epochs = 2 )
465+ proposal = trainer .build_posterior ().set_default_x (torch .zeros (num_dim ))
466+ theta = proposal .sample ((num_simulations ,))
467+ x = simulator (theta )
0 commit comments