Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions meridian/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,13 +815,13 @@
BIC = 'bic'
EBIC = 'ebic'

# Posterior downsampling constants.
POSTERIOR_IS_DOWNSAMPLED = 'posterior_is_downsampled'
POSTERIOR_DOWNSAMPLE_METHOD = 'posterior_downsample_method'
POSTERIOR_DOWNSAMPLE_SAMPLING_RATE = 'posterior_downsample_sampling_rate'
# Posterior thinning constants.
POSTERIOR_IS_THINNED = 'posterior_is_thinned'
POSTERIOR_THINNING_METHOD = 'posterior_thinning_method'
POSTERIOR_THINNING_SAMPLING_RATE = 'posterior_thinning_sampling_rate'
POSTERIOR_ORIGINAL_CHAIN_COUNT = 'posterior_original_chain_count'
POSTERIOR_ORIGINAL_DRAW_COUNT = 'posterior_original_draw_count'
POSTERIOR_SELECTED_DRAW_COUNT_PER_CHAIN = (
'posterior_selected_draw_count_per_chain'
)
POSTERIOR_DOWNSAMPLE_SEED = 'posterior_downsample_seed'
POSTERIOR_THINNING_SEED = 'posterior_thinning_seed'
46 changes: 23 additions & 23 deletions meridian/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@


@enum.unique
class DownsampleMethod(enum.Enum):
"""Posterior draw downsampling methods."""
class ThinningMethod(enum.Enum):
"""Posterior draw thinning methods."""

SYSTEMATIC = "systematic"

Expand Down Expand Up @@ -1219,15 +1219,15 @@ def sample_posterior_and_review(
)
self.review()

def downsample_posterior(
def posterior_thinning(
self,
sampling_rate: float | None = None,
n_draws: int | None = None,
method: DownsampleMethod = DownsampleMethod.SYSTEMATIC,
method: ThinningMethod = ThinningMethod.SYSTEMATIC,
seed: int | Sequence[int] | None = None,
preserve_original: bool = True,
) -> xr.Dataset:
"""Downsamples `inference_data.posterior` while preserving chains.
"""Thins `inference_data.posterior` while preserving chains.

This method replaces `self.inference_data.posterior` with a chain-preserving
subset of posterior draws. For example, a posterior with shape
Expand All @@ -1236,7 +1236,7 @@ def downsample_posterior(

The main use case is accelerating posterior workflows such as budget
optimization while continuing to use Meridian's existing APIs unchanged.
Outputs produced after downsampling are approximate with respect to the full
Outputs produced after thinning are approximate with respect to the full
posterior.

Systematic sampling selects posterior samples from each MCMC chain at a
Expand All @@ -1250,33 +1250,33 @@ def downsample_posterior(
original_n_draws]`. Exactly one of `sampling_rate` or `n_draws` must be
provided.
method: Draw selection method. Currently only
`DownsampleMethod.SYSTEMATIC` is supported.
`ThinningMethod.SYSTEMATIC` is supported.
seed: Optional random seed for reproducible draw selection. This is used
only for selecting posterior draw indices.
preserve_original: If `True`, stores a copy of the full posterior on this
model so `restore_full_posterior()` can restore it.

Returns:
The downsampled posterior `xarray.Dataset`.
The thinned posterior `xarray.Dataset`.

Raises:
NotFittedModelError: If the model does not have posterior samples.
ValueError: If arguments are invalid.
"""
if not hasattr(self.inference_data, constants.POSTERIOR):
raise NotFittedModelError(
"sample_posterior() must be called before downsample_posterior()."
"sample_posterior() must be called before posterior_thinning()."
)
if (sampling_rate is None) == (n_draws is None):
raise ValueError(
"Exactly one of `sampling_rate` or `n_draws` must be provided."
)
if method is not DownsampleMethod.SYSTEMATIC:
raise ValueError(f"Unsupported posterior downsample method: {method}.")
if method is not ThinningMethod.SYSTEMATIC:
raise ValueError(f"Unsupported posterior thinning method: {method}.")

posterior = self.inference_data.posterior
if posterior.attrs.get(constants.POSTERIOR_IS_DOWNSAMPLED):
raise ValueError("Posterior has already been downsampled.")
if posterior.attrs.get(constants.POSTERIOR_IS_THINNED):
raise ValueError("Posterior has already been thinned.")
if (
constants.CHAIN not in posterior.sizes
or constants.DRAW not in posterior.sizes
Expand Down Expand Up @@ -1315,14 +1315,14 @@ def downsample_posterior(
selected_draw_indices,
dims=(constants.CHAIN, constants.DRAW),
)
downsampled_posterior = posterior.isel(
thinned_posterior = posterior.isel(
{constants.DRAW: draw_indexer}
).assign_coords({constants.DRAW: np.arange(n_selected_draws)})
attrs = dict(posterior.attrs)
attrs.update({
constants.POSTERIOR_IS_DOWNSAMPLED: True,
constants.POSTERIOR_DOWNSAMPLE_METHOD: method.value,
constants.POSTERIOR_DOWNSAMPLE_SAMPLING_RATE: (
constants.POSTERIOR_IS_THINNED: True,
constants.POSTERIOR_THINNING_METHOD: method.value,
constants.POSTERIOR_THINNING_SAMPLING_RATE: (
float(sampling_rate)
if sampling_rate is not None
else n_selected_draws / n_original_draws
Expand All @@ -1332,21 +1332,21 @@ def downsample_posterior(
constants.POSTERIOR_SELECTED_DRAW_COUNT_PER_CHAIN: n_selected_draws,
})
if seed is not None:
attrs[constants.POSTERIOR_DOWNSAMPLE_SEED] = (
attrs[constants.POSTERIOR_THINNING_SEED] = (
list(seed)
if isinstance(seed, Sequence) and not isinstance(seed, (str, bytes))
else int(seed)
)
downsampled_posterior.attrs = attrs
self.inference_data.posterior = downsampled_posterior
return downsampled_posterior
thinned_posterior.attrs = attrs
self.inference_data.posterior = thinned_posterior
return thinned_posterior

def restore_full_posterior(self) -> xr.Dataset:
"""Restores the full posterior saved by `downsample_posterior()`."""
"""Restores the full posterior saved by `posterior_thinning()`."""
if self._full_posterior is None:
raise ValueError(
"No preserved full posterior is available. Call "
"downsample_posterior(..., preserve_original=True) first."
"posterior_thinning(..., preserve_original=True) first."
)
self.inference_data.posterior = self._full_posterior
self._full_posterior = None
Expand Down
76 changes: 38 additions & 38 deletions meridian/model/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ def _meridian_with_posterior(self, posterior: xr.Dataset) -> model.Meridian:
meridian.inference_data.posterior = posterior
return meridian

def test_downsample_posterior_preserves_chains(self):
def test_posterior_thinning_preserves_chains(self):
values = np.arange(3 * 10 * 2).reshape((3, 10, 2))
meridian = self._meridian_with_posterior(xr.Dataset(
data_vars={
Expand All @@ -744,26 +744,26 @@ def test_downsample_posterior_preserves_chains(self):
},
))

downsampled = meridian.downsample_posterior(n_draws=4, seed=7)
thinned = meridian.posterior_thinning(n_draws=4, seed=7)

self.assertEqual(downsampled.sizes[constants.CHAIN], 3)
self.assertEqual(downsampled.sizes[constants.DRAW], 4)
self.assertEqual(downsampled.sizes["channel"], 2)
self.assertEqual(thinned.sizes[constants.CHAIN], 3)
self.assertEqual(thinned.sizes[constants.DRAW], 4)
self.assertEqual(thinned.sizes["channel"], 2)
self.assertEqual(
downsampled.attrs["posterior_selected_draw_count_per_chain"], 4
thinned.attrs["posterior_selected_draw_count_per_chain"], 4
)
self.assertTrue(downsampled.attrs["posterior_is_downsampled"])
self.assertTrue(thinned.attrs["posterior_is_thinned"])
self.assertEqual(
downsampled.attrs["posterior_downsample_method"], "systematic"
thinned.attrs["posterior_thinning_method"], "systematic"
)
for chain in range(3):
with self.subTest(chain=chain):
selected_draws = downsampled["draw_id"].sel(chain=chain).values
selected_draws = thinned["draw_id"].sel(chain=chain).values
self.assertLen(set(selected_draws.tolist()), 4)
self.assertTrue(np.all(np.diff(selected_draws) >= 1))
self.assertTrue(np.all(selected_draws >= 0))
self.assertTrue(np.all(selected_draws < 10))
selected_values = downsampled["param"].sel(chain=chain).values[:, 0]
selected_values = thinned["param"].sel(chain=chain).values[:, 0]
self.assertTrue(np.all(selected_values >= chain * 20))
self.assertTrue(np.all(selected_values < (chain + 1) * 20))

Expand All @@ -773,18 +773,18 @@ def test_downsample_posterior_preserves_chains(self):
self.assertEqual(restored.sizes[constants.DRAW], 10)
np.testing.assert_array_equal(restored["param"].values, values)

def test_downsample_posterior_accepts_downsample_method_enum(self):
def test_posterior_thinning_accepts_thinning_method_enum(self):
meridian = self._meridian_with_posterior(_simple_posterior())

downsampled = meridian.downsample_posterior(
n_draws=4, method=model.DownsampleMethod.SYSTEMATIC, seed=7
thinned = meridian.posterior_thinning(
n_draws=4, method=model.ThinningMethod.SYSTEMATIC, seed=7
)

self.assertEqual(
downsampled.attrs["posterior_downsample_method"], "systematic"
thinned.attrs["posterior_thinning_method"], "systematic"
)

def test_downsample_posterior_supports_non_leading_draw_dimension(self):
def test_posterior_thinning_supports_non_leading_draw_dimension(self):
values = np.arange(2 * 3 * 10).reshape((2, 3, 10))
meridian = self._meridian_with_posterior(xr.Dataset(
data_vars={
Expand All @@ -800,14 +800,14 @@ def test_downsample_posterior_supports_non_leading_draw_dimension(self):
},
))

downsampled = meridian.downsample_posterior(n_draws=4, seed=7)
thinned = meridian.posterior_thinning(n_draws=4, seed=7)

self.assertEqual(
downsampled["param"].dims, ("channel", constants.CHAIN, constants.DRAW)
thinned["param"].dims, ("channel", constants.CHAIN, constants.DRAW)
)
self.assertEqual(downsampled.sizes[constants.CHAIN], 3)
self.assertEqual(downsampled.sizes[constants.DRAW], 4)
self.assertEqual(downsampled.sizes["channel"], 2)
self.assertEqual(thinned.sizes[constants.CHAIN], 3)
self.assertEqual(thinned.sizes[constants.DRAW], 4)
self.assertEqual(thinned.sizes["channel"], 2)

@parameterized.named_parameters(
dict(
Expand All @@ -833,7 +833,7 @@ def test_systematic_draw_indices_returns_exact_count(
self.assertTrue(np.all(selected >= 0))
self.assertTrue(np.all(selected < n_original_draws))

def test_downsample_posterior_seed_reproducible(self):
def test_posterior_thinning_seed_reproducible(self):
values = np.arange(2 * 30).reshape((2, 30))
first_meridian = self._meridian_with_posterior(xr.Dataset(
data_vars={
Expand All @@ -851,63 +851,63 @@ def test_downsample_posterior_seed_reproducible(self):
first_meridian.inference_data.posterior.copy(deep=True)
)

first = first_meridian.downsample_posterior(n_draws=5, seed=7)
second = second_meridian.downsample_posterior(n_draws=5, seed=7)
first = first_meridian.posterior_thinning(n_draws=5, seed=7)
second = second_meridian.posterior_thinning(n_draws=5, seed=7)

self.assertEqual(first.attrs["posterior_downsample_seed"], 7)
self.assertEqual(first.attrs["posterior_thinning_seed"], 7)
np.testing.assert_array_equal(first["param"].values, second["param"].values)

def test_downsample_posterior_requires_posterior(self):
def test_posterior_thinning_requires_posterior(self):
meridian = model.Meridian(input_data=self.input_data_with_media_only)

with self.assertRaises(model.NotFittedModelError):
meridian.downsample_posterior(sampling_rate=0.1)
meridian.posterior_thinning(sampling_rate=0.1)

@parameterized.named_parameters(
dict(testcase_name="missing", kwargs={}),
dict(testcase_name="both", kwargs={"sampling_rate": 0.1, "n_draws": 2}),
)
def test_downsample_posterior_requires_exactly_one_draw_argument(
def test_posterior_thinning_requires_exactly_one_draw_argument(
self, kwargs
):
meridian = self._meridian_with_posterior(_simple_posterior())

with self.assertRaisesRegex(ValueError, "Exactly one"):
meridian.downsample_posterior(**kwargs)
meridian.posterior_thinning(**kwargs)

@parameterized.named_parameters(
dict(testcase_name="zero", kwargs={"n_draws": 0}),
dict(testcase_name="too_many", kwargs={"n_draws": 11}),
)
def test_downsample_posterior_rejects_invalid_n_draws(self, kwargs):
def test_posterior_thinning_rejects_invalid_n_draws(self, kwargs):
meridian = self._meridian_with_posterior(_simple_posterior())

with self.assertRaisesRegex(ValueError, "`n_draws`"):
meridian.downsample_posterior(**kwargs)
meridian.posterior_thinning(**kwargs)

@parameterized.named_parameters(
dict(testcase_name="zero", kwargs={"sampling_rate": 0}),
dict(testcase_name="too_large", kwargs={"sampling_rate": 1.1}),
)
def test_downsample_posterior_rejects_invalid_sampling_rate(self, kwargs):
def test_posterior_thinning_rejects_invalid_sampling_rate(self, kwargs):
meridian = self._meridian_with_posterior(_simple_posterior())

with self.assertRaisesRegex(ValueError, "`sampling_rate`"):
meridian.downsample_posterior(**kwargs)
meridian.posterior_thinning(**kwargs)

def test_downsample_posterior_rejects_invalid_method(self):
def test_posterior_thinning_rejects_invalid_method(self):
meridian = self._meridian_with_posterior(_simple_posterior())

with self.assertRaisesRegex(ValueError, "Unsupported"):
meridian.downsample_posterior(n_draws=4, method=mock.MagicMock())
meridian.posterior_thinning(n_draws=4, method=mock.MagicMock())

def test_downsample_posterior_rejects_downsampling_twice(self):
def test_posterior_thinning_rejects_thinning_twice(self):
meridian = self._meridian_with_posterior(_simple_posterior())

meridian.downsample_posterior(n_draws=4, seed=7)
meridian.posterior_thinning(n_draws=4, seed=7)

with self.assertRaisesRegex(ValueError, "already been downsampled"):
meridian.downsample_posterior(n_draws=2)
with self.assertRaisesRegex(ValueError, "already been thinned"):
meridian.posterior_thinning(n_draws=2)


class ModelPersistenceTest(
Expand Down
Loading