|
21 | 21 | NPE, |
22 | 22 | NPE_A, |
23 | 23 | NPE_C, |
24 | | - NPSE, |
25 | 24 | NRE_A, |
26 | 25 | NRE_B, |
27 | 26 | NRE_C, |
|
38 | 37 | from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential |
39 | 38 | from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential |
40 | 39 | from sbi.inference.potentials.ratio_based_potential import RatioBasedPotential |
41 | | -from sbi.inference.trainers.npse.vector_field_inference import VectorFieldInference |
42 | 40 | from sbi.neural_nets.embedding_nets import FCEmbedding |
43 | 41 | from sbi.neural_nets.factory import ( |
44 | 42 | classifier_nn, |
@@ -625,7 +623,12 @@ def test_direct_posterior_on_gpu(device: str, device_inference: str): |
625 | 623 | ], |
626 | 624 | ) |
627 | 625 | def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotential]): |
628 | | - """Test to method on potential""" |
| 626 | + """Test .to() method on potential. |
| 627 | +
|
| 628 | + Args: |
| 629 | + device: device to move the potential to. |
| 630 | + potential: potential class whose .to() method is tested. |
| 631 | + """ |
629 | 632 |
|
630 | 633 | device = process_device(device) |
631 | 634 | prior = BoxUniform(torch.tensor([1.0]), torch.tensor([1.0])) |
@@ -665,7 +668,12 @@ def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotentia |
665 | 668 | "sampling_method", ["rejection", "importance", "mcmc", "direct"] |
666 | 669 | ) |
667 | 670 | def test_to_method_on_posteriors(device: str, sampling_method: str): |
668 | | - """Test that the .to() method works on posteriors.""" |
| 671 | + """Test .to() method on posteriors. |
| 672 | +
|
| 673 | + Args: |
| 674 | + device: device to train and sample the model on. |
| 675 | + sampling_method: method used to sample from the posterior. |
| 676 | + """ |
669 | 677 | device = process_device(device) |
670 | 678 | prior = BoxUniform(torch.zeros(3), torch.ones(3)) |
671 | 679 | inference = NPE() |
@@ -708,25 +716,44 @@ def test_to_method_on_posteriors(device: str, sampling_method: str): |
708 | 716 |
|
709 | 717 | @pytest.mark.gpu |
710 | 718 | @pytest.mark.parametrize("device", ["cpu", "gpu"]) |
711 | | -@pytest.mark.parametrize("iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss"]) |
712 | | -@pytest.mark.parametrize("inference_method", [FMPE, NPSE]) |
713 | | -def test_VectorFieldPosterior( |
714 | | - device: str, iid_method: str, inference_method: VectorFieldInference |
| 719 | +@pytest.mark.parametrize("device_inference", ["cpu", "gpu"]) |
| 720 | +@pytest.mark.parametrize( |
| 721 | + "iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None] |
| 722 | +) |
| 723 | +def test_VectorFieldPosterior_device_handling( |
| 724 | + device: str, device_inference: str, iid_method: str |
715 | 725 | ): |
| 726 | + """Test VectorFieldPosterior on different devices training and inference devices. |
| 727 | +
|
| 728 | + Args: |
| 729 | + device: device to train the model on. |
| 730 | + device_inference: device to run the inference on. |
| 731 | + iid_method: iid sampling method passed to posterior.sample. |
| 732 | + """ |
716 | 733 | device = process_device(device) |
717 | | - prior = BoxUniform(torch.zeros(3), torch.ones(3), device="cpu") |
718 | | - inference = inference_method(score_estimator="mlp", prior=prior) |
| 734 | + device_inference = process_device(device_inference) |
| 735 | + prior = BoxUniform(torch.zeros(3), torch.ones(3), device=device) |
| 736 | + inference = FMPE(score_estimator="mlp", prior=prior, device=device) |
719 | 737 | density_estimator = inference.append_simulations( |
720 | 738 | torch.randn((100, 3)), torch.randn((100, 2)) |
721 | 739 | ).train(max_num_epochs=1) |
722 | 740 | posterior = inference.build_posterior(density_estimator, prior) |
723 | | - posterior.to(device) |
724 | | - assert posterior.device == device, ( |
725 | | - f"VectorFieldPosterior is not in device {device}." |
| 741 | + |
| 742 | + # speed up log_prob computation at the cost of accuracy (non-exact, looser tolerances) |
| 743 | + posterior.potential_fn.neural_ode.update_params(exact=False, atol=1e-4, rtol=1e-4) |
| 744 | + |
| 745 | + posterior.to(device_inference) |
| 746 | + assert posterior.device == device_inference, ( |
| 747 | + f"VectorFieldPosterior is not in device {device_inference}." |
726 | 748 | ) |
727 | 749 |
|
728 | | - x_o = torch.ones(2).to(device) |
| 750 | + x_o = torch.ones(2).to(device_inference) |
729 | 751 | samples = posterior.sample((2,), x=x_o, iid_method=iid_method) |
730 | | - assert samples.device.type == device.split(":")[0], ( |
731 | | - f"Samples are not on device {device}." |
| 752 | + assert samples.device.type == device_inference.split(":")[0], ( |
| 753 | + f"Samples are not on device {device_inference}." |
| 754 | + ) |
| 755 | + |
| 756 | + log_probs = posterior.log_prob(samples, x=x_o) |
| 757 | + assert log_probs.device.type == device_inference.split(":")[0], ( |
| 758 | + f"log_prob was not correctly moved to {device_inference}." |
732 | 759 | ) |
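
For context, the device-handling pattern these tests exercise can be summarized as follows. This is a minimal sketch assembled only from the calls visible in the diff above (FMPE, append_simulations, train, build_posterior, .to, sample, log_prob); the constructor arguments (score_estimator="mlp", device=...) and the device strings are taken from the test and may differ across sbi versions.

```python
import torch

from sbi.inference import FMPE
from sbi.utils import BoxUniform

# Train on one device, then move the posterior and run inference on another.
# The device strings below are placeholders for illustration.
train_device = "cpu"
inference_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Prior and trainer live on the training device (mirrors the test above).
prior = BoxUniform(torch.zeros(3), torch.ones(3), device=train_device)
inference = FMPE(score_estimator="mlp", prior=prior, device=train_device)

# Dummy simulations, as in the test: theta has 3 dims, x has 2 dims.
theta, x = torch.randn((100, 3)), torch.randn((100, 2))
estimator = inference.append_simulations(theta, x).train(max_num_epochs=1)

posterior = inference.build_posterior(estimator, prior)
posterior.to(inference_device)  # moves the posterior and its potential_fn

x_o = torch.ones(2).to(inference_device)
samples = posterior.sample((2,), x=x_o)          # returned on inference_device
log_probs = posterior.log_prob(samples, x=x_o)   # also on inference_device
```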