
Commit 2db5860

jorobledo and janfb authored

tests: gpu test for VectorFieldPosterior (#1542)

* improve test for VectorFieldPosterior and add docstrings to GPU tests
* fix _loss so that all tensors are moved to the correct device in VectorFieldInference
* apply suggestions from code review
* fix linting

Co-authored-by: Jan <janfb@users.noreply.github.com>

1 parent fb5124e commit 2db5860

File tree

2 files changed: +47, -16 lines

sbi/inference/trainers/npse/vector_field_inference.py
tests/inference_on_device_test.py


sbi/inference/trainers/npse/vector_field_inference.py

Lines changed: 4 additions & 0 deletions

@@ -585,6 +585,10 @@ def _loss(
         Returns:
             Calibration kernel-weighted loss implemented by the vector field estimator.
         """
+
+        if times is not None:
+            times = times.to(self._device)
+
         cls_name = self.__class__.__name__
         if self._round == 0 or force_first_round_loss:
             # First round loss.
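
The added lines move the optional `times` tensor onto the trainer's device before the loss is computed, so a CPU-created tensor cannot collide with an estimator living on the GPU. Below is a minimal sketch of that pattern; `SketchTrainer` is a hypothetical class for illustration, not sbi's actual VectorFieldInference.

    import torch
    from typing import Optional


    class SketchTrainer:
        """Hypothetical trainer illustrating the device handling fixed above."""

        def __init__(self, device: str = "cpu") -> None:
            self._device = torch.device(device)
            self._net = torch.nn.Linear(3, 1).to(self._device)

        def _loss(
            self, theta: torch.Tensor, times: Optional[torch.Tensor] = None
        ) -> torch.Tensor:
            # Move every incoming tensor, including optional ones, to the
            # trainer's device; otherwise a CPU `times` meeting a CUDA model
            # raises a device-mismatch RuntimeError.
            theta = theta.to(self._device)
            if times is not None:
                times = times.to(self._device)
            per_sample = self._net(theta).squeeze(-1).pow(2)
            if times is not None:
                per_sample = per_sample * times  # time-dependent weighting
            return per_sample.mean()


    # Usage: the loss works regardless of where `times` was created.
    trainer = SketchTrainer(device="cuda" if torch.cuda.is_available() else "cpu")
    loss = trainer._loss(torch.randn(8, 3), times=torch.rand(8))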

tests/inference_on_device_test.py

Lines changed: 43 additions & 16 deletions

@@ -21,7 +21,6 @@
     NPE,
     NPE_A,
     NPE_C,
-    NPSE,
     NRE_A,
     NRE_B,
     NRE_C,
@@ -38,7 +37,6 @@
 from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential
 from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
 from sbi.inference.potentials.ratio_based_potential import RatioBasedPotential
-from sbi.inference.trainers.npse.vector_field_inference import VectorFieldInference
 from sbi.neural_nets.embedding_nets import FCEmbedding
 from sbi.neural_nets.factory import (
     classifier_nn,
@@ -625,7 +623,12 @@ def test_direct_posterior_on_gpu(device: str, device_inference: str):
     ],
 )
 def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotential]):
-    """Test to method on potential"""
+    """Test .to() method on potential.
+
+    Args:
+        device: device where to move the model.
+        potential: potential to train the model on.
+    """

     device = process_device(device)
     prior = BoxUniform(torch.tensor([1.0]), torch.tensor([1.0]))
@@ -665,7 +668,12 @@ def test_to_method_on_potentials(device: str, potential: Union[ABC, BasePotential]):
     "sampling_method", ["rejection", "importance", "mcmc", "direct"]
 )
 def test_to_method_on_posteriors(device: str, sampling_method: str):
-    """Test that the .to() method works on posteriors."""
+    """Test .to() method on posteriors.
+
+    Args:
+        device: device to train and sample the model on.
+        sampling_method: method to sample from the posterior.
+    """
     device = process_device(device)
     prior = BoxUniform(torch.zeros(3), torch.ones(3))
     inference = NPE()
@@ -708,25 +716,44 @@ def test_to_method_on_posteriors(device: str, sampling_method: str):

 @pytest.mark.gpu
 @pytest.mark.parametrize("device", ["cpu", "gpu"])
-@pytest.mark.parametrize("iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss"])
-@pytest.mark.parametrize("inference_method", [FMPE, NPSE])
-def test_VectorFieldPosterior(
-    device: str, iid_method: str, inference_method: VectorFieldInference
+@pytest.mark.parametrize("device_inference", ["cpu", "gpu"])
+@pytest.mark.parametrize(
+    "iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None]
+)
+def test_VectorFieldPosterior_device_handling(
+    device: str, device_inference: str, iid_method: str
 ):
+    """Test VectorFieldPosterior on different training and inference devices.
+
+    Args:
+        device: device to train the model on.
+        device_inference: device to run the inference on.
+        iid_method: method to sample from the posterior.
+    """
     device = process_device(device)
-    prior = BoxUniform(torch.zeros(3), torch.ones(3), device="cpu")
-    inference = inference_method(score_estimator="mlp", prior=prior)
+    device_inference = process_device(device_inference)
+    prior = BoxUniform(torch.zeros(3), torch.ones(3), device=device)
+    inference = FMPE(score_estimator="mlp", prior=prior, device=device)
     density_estimator = inference.append_simulations(
         torch.randn((100, 3)), torch.randn((100, 2))
     ).train(max_num_epochs=1)
     posterior = inference.build_posterior(density_estimator, prior)
-    posterior.to(device)
-    assert posterior.device == device, (
-        f"VectorFieldPosterior is not in device {device}."
+
+    # faster but inaccurate log_prob computation
+    posterior.potential_fn.neural_ode.update_params(exact=False, atol=1e-4, rtol=1e-4)
+
+    posterior.to(device_inference)
+    assert posterior.device == device_inference, (
+        f"VectorFieldPosterior is not in device {device_inference}."
     )

-    x_o = torch.ones(2).to(device)
+    x_o = torch.ones(2).to(device_inference)
     samples = posterior.sample((2,), x=x_o, iid_method=iid_method)
-    assert samples.device.type == device.split(":")[0], (
-        f"Samples are not on device {device}."
+    assert samples.device.type == device_inference.split(":")[0], (
+        f"Samples are not on device {device_inference}."
+    )
+
+    log_probs = posterior.log_prob(samples, x=x_o)
+    assert log_probs.device.type == device_inference.split(":")[0], (
+        f"log_prob was not correctly moved to {device_inference}."
     )
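
Outside of pytest, the workflow the new test exercises (train FMPE on one device, move the resulting VectorFieldPosterior with `.to()`, then sample and evaluate `log_prob` on the inference device) looks roughly as sketched below. The calls mirror the diff above; the import paths and the CUDA availability check are assumptions rather than something this commit specifies.

    import torch
    from sbi.inference import FMPE      # import path assumed for recent sbi versions
    from sbi.utils import BoxUniform    # import path assumed

    train_device = "cuda" if torch.cuda.is_available() else "cpu"
    inference_device = "cpu"

    # Train a vector-field (flow matching) estimator on the training device.
    prior = BoxUniform(torch.zeros(3), torch.ones(3), device=train_device)
    inference = FMPE(score_estimator="mlp", prior=prior, device=train_device)
    estimator = inference.append_simulations(
        torch.randn((100, 3)), torch.randn((100, 2))
    ).train(max_num_epochs=1)

    # Build the posterior, then move it to the inference device.
    posterior = inference.build_posterior(estimator, prior)
    posterior.to(inference_device)

    x_o = torch.ones(2).to(inference_device)
    samples = posterior.sample((2,), x=x_o)
    log_probs = posterior.log_prob(samples, x=x_o)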
