Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add enforce_likelihood_threshold to FlowProposal #452

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
60 changes: 52 additions & 8 deletions src/nessai/proposal/flowproposal/flowproposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,18 @@
samples. This is translated to a value for ``fuzz``.
truncate_log_q : bool, optional
Truncate proposals using minimum log-probability of the training data.
enforce_likelihood_threshold : bool
If True, enforce the likelihood threshold when performing rejection
sampling. If false, the likelihood is not checked when populating the
proposal.
"""

def __init__(
self,
model,
poolsize=None,
latent_prior="truncated_gaussian",
latent_temperature=None,
constant_volume_mode=True,
volume_fraction=0.95,
fuzz=1.0,
Expand All @@ -89,6 +94,7 @@
min_radius=False,
max_radius=50.0,
compute_radius_with_all=False,
enforce_likelihood_threshold=False,
**kwargs,
):
super().__init__(
Expand All @@ -106,11 +112,13 @@
fuzz,
expansion_fraction,
latent_prior,
latent_temperature,
)

self.truncate_log_q = truncate_log_q
self.constant_volume_mode = constant_volume_mode
self.volume_fraction = volume_fraction
self.enforce_likelihood_threshold = enforce_likelihood_threshold

self.compute_radius_with_all = compute_radius_with_all
self.configure_fixed_radius(fixed_radius)
Expand All @@ -125,6 +133,7 @@
fuzz,
expansion_fraction,
latent_prior,
latent_temperature=None,
):
"""
Configure settings related to population
Expand All @@ -136,6 +145,14 @@
self.fuzz = fuzz
self.expansion_fraction = expansion_fraction
self.latent_prior = latent_prior
if latent_temperature is not None:
if latent_prior != "gaussian":
raise ValueError(
"Latent temperature can only be used with a Gaussian latent prior"
)
else:
logger.warning("`latent_temperature` is experimental!")
self.latent_temperature = latent_temperature

def configure_latent_prior(self):
"""Configure the latent prior"""
Expand Down Expand Up @@ -294,14 +311,19 @@
elif self.latent_prior == "flow":
self._draw_func = lambda N: self.flow.sample_latent_distribution(N)
else:
assert self.rng is not None
self._draw_func = partial(
self._draw_latent_prior,
draw_kwargs = dict(
dims=self.dims,
r=self.r,
fuzz=self.fuzz,
rng=self.rng,
)
if self.latent_temperature is not None:
draw_kwargs["temperature"] = self.latent_temperature
assert self.rng is not None
self._draw_func = partial(
self._draw_latent_prior,
**draw_kwargs,
)

def draw_latent_prior(self, n):
"""Draw n samples from the latent prior."""
Expand Down Expand Up @@ -432,7 +454,11 @@
min_log_q = None

logger.debug(f"Populating proposal with latent radius: {r:.5}")
self.r = r
# Radius is not used for the flow or Gaussian latent priors
if self.latent_prior in ["flow", "gaussian"]:
self.r = np.nan

Check warning on line 459 in src/nessai/proposal/flowproposal/flowproposal.py

View check run for this annotation

Codecov / codecov/patch

src/nessai/proposal/flowproposal/flowproposal.py#L459

Added line #L459 was not covered by tests
else:
self.r = r

self.alt_dist = self.get_alt_distribution()

Expand All @@ -459,6 +485,7 @@
log_constant = -np.inf
n_accepted = 0
accept = None
likelihood_threshold = worst_point["logL"]

while n_accepted < n_samples:
z = self.draw_latent_prior(self.drawsize)
Expand All @@ -476,8 +503,24 @@
self.drawsize - above_min_log_q.sum(),
)
x, log_q = get_subset_arrays(above_min_log_q, x, log_q)
# Handle case where all samples are below min_log_q

if self.enforce_likelihood_threshold:
x["logL"] = self.model.batch_evaluate_log_likelihood(

Check warning on line 508 in src/nessai/proposal/flowproposal/flowproposal.py

View check run for this annotation

Codecov / codecov/patch

src/nessai/proposal/flowproposal/flowproposal.py#L508

Added line #L508 was not covered by tests
x, unit_hypercube=self.map_to_unit_hypercube
)
above_threshold = x["logL"] > likelihood_threshold
x, log_q = get_subset_arrays(above_threshold, x, log_q)
logger.debug(

Check warning on line 513 in src/nessai/proposal/flowproposal/flowproposal.py

View check run for this annotation

Codecov / codecov/patch

src/nessai/proposal/flowproposal/flowproposal.py#L511-L513

Added lines #L511 - L513 were not covered by tests
"Accepting %s / %s samples above logL threshold",
len(x),
self.drawsize,
)
# Handle case where all samples have been discarded
if not len(x):
logger.warning(

Check warning on line 520 in src/nessai/proposal/flowproposal/flowproposal.py

View check run for this annotation

Codecov / codecov/patch

src/nessai/proposal/flowproposal/flowproposal.py#L520

Added line #L520 was not covered by tests
"All samples were discard before performing rejection "
"sampling."
)
continue
log_w = self.compute_weights(x, log_q)

Expand Down Expand Up @@ -532,9 +575,10 @@

self.population_time += datetime.datetime.now() - st
logger.debug("Evaluating log-likelihoods")
self.samples["logL"] = self.model.batch_evaluate_log_likelihood(
self.samples
)
if not self.enforce_likelihood_threshold:
self.samples["logL"] = self.model.batch_evaluate_log_likelihood(
self.samples
)
if self.check_acceptance:
self.acceptance.append(
self.compute_acceptance(worst_point["logL"])
Expand Down
4 changes: 2 additions & 2 deletions src/nessai/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def draw_uniform(dims, r=(1,), N=1000, fuzz=1.0, rng=None):
return rng.random((N, dims))


def draw_gaussian(dims, r=1, N=1000, fuzz=1.0, rng=None):
def draw_gaussian(dims, r=1, N=1000, fuzz=1.0, rng=None, temperature=1):
"""
Wrapper for numpy.random.randn that deals with extra input parameters
r and fuzz
Expand All @@ -145,7 +145,7 @@ def draw_gaussian(dims, r=1, N=1000, fuzz=1.0, rng=None):
if rng is None:
logger.debug("No rng specified, using the default rng.")
rng = np.random.default_rng()
return rng.standard_normal((N, dims))
return np.sqrt(temperature) * rng.standard_normal((N, dims))


def draw_truncated_gaussian(dims, r, N=1000, fuzz=1.0, var=1, rng=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def proposal(rng):
proposal = create_autospec(FlowProposal)
proposal._initialised = False
proposal.accumulate_weights = False
proposal.enforce_likelihood_threshold = False
proposal.map_to_unit_hypercube = False
proposal.rng = rng
return proposal
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,25 @@ def test_config_drawsize_none(proposal):
assert proposal.drawsize == 2000


def test_configure_latent_temperature(proposal):
"""Test the configuration of the latent temperature"""
FlowProposal.configure_population(
proposal, 1000, 1.0, 0.0, "gaussian", 0.9
)
assert proposal.latent_temperature == 0.9


def test_configure_latent_temperature_invalid(proposal):
"""Test the configuration of the latent temperature"""
with pytest.raises(
ValueError,
match="Latent temperature can only be used with a Gaussian latent",
):
FlowProposal.configure_population(
proposal, 1000, 1.0, 0.0, "truncated_gaussian", 0.9
)


@pytest.mark.parametrize("fixed_radius", [False, 5.0, 1])
def test_config_fixed_radius(proposal, fixed_radius):
"""Test the configuration for a fixed radius"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ def test_prep_latent_prior_truncated(proposal):
def test_prep_latent_prior_other(proposal):
"""Assert partial acts as expected"""
proposal.latent_prior = "gaussian"
proposal.latent_temperature = 0.9
proposal.dims = 2
proposal.r = 3.0
proposal.fuzz = 1.2

def draw(dims, N=None, r=None, fuzz=None, rng=None):
def draw(dims, N=None, r=None, fuzz=None, rng=None, temperature=None):
return np.zeros((N, dims))

proposal._draw_latent_prior = draw
Expand All @@ -119,7 +120,7 @@ def draw(dims, N=None, r=None, fuzz=None, rng=None):
FlowProposal.prep_latent_prior(proposal)

mock_partial.assert_called_once_with(
draw, dims=2, r=3.0, fuzz=1.2, rng=proposal.rng
draw, dims=2, r=3.0, fuzz=1.2, rng=proposal.rng, temperature=0.9
)

assert proposal._draw_func(N=10).shape == (10, 2)
Expand Down Expand Up @@ -237,6 +238,7 @@ def test_populate_accumulate_weights(
proposal.get_alt_distribution = MagicMock(return_value=None)
proposal.prep_latent_prior = MagicMock()
proposal.draw_latent_prior = MagicMock(side_effect=z)
proposal.latent_prior = "truncated_gaussian"
proposal.compute_weights = MagicMock(side_effect=log_w)
proposal.compute_acceptance = MagicMock(return_value=0.8)
proposal.model = MagicMock()
Expand Down Expand Up @@ -427,6 +429,7 @@ def test_populate_not_accumulate_weights(
proposal.backward_pass = MagicMock(side_effect=zip(x, log_q))
proposal.radius = MagicMock(return_value=r_flow)
proposal.get_alt_distribution = MagicMock(return_value=None)
proposal.latent_prior = "truncated_gaussian"
proposal.prep_latent_prior = MagicMock()
proposal.draw_latent_prior = MagicMock(side_effect=z)
proposal.compute_weights = MagicMock(side_effect=log_w)
Expand Down Expand Up @@ -595,6 +598,7 @@ def test_populate_truncate_log_q(proposal, rng):
names=names,
)
proposal.accumulate_weights = True
proposal.latent_prior = "truncated_gaussian"

log_q_live = np.zeros(nlive)
log_q_live[-1] = -1.0
Expand Down