Skip to content

Commit 555f926

Browse files
andyl7anThe Meridian Authors
authored andcommitted
Optimize MCMC performance by decoupling sampling and reconstruction graphs
PiperOrigin-RevId: 866152466
1 parent 254216b commit 555f926

File tree

5 files changed

+750
-129
lines changed

5 files changed

+750
-129
lines changed

meridian/backend/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,7 @@ def _jax_adstock_process(
10721072
roll = _jax_roll
10731073
split = _jax_split
10741074
stack = _ops.stack
1075+
squeeze = _ops.squeeze
10751076
tile = _jax_tile
10761077
transpose = _jax_transpose
10771078
unique_with_counts = _jax_unique_with_counts
@@ -1254,6 +1255,7 @@ def _tf_adstock_process(
12541255
set_random_seed = tf_backend.keras.utils.set_random_seed
12551256
split = _ops.split
12561257
stack = _ops.stack
1258+
squeeze = _ops.squeeze
12571259
tile = _ops.tile
12581260
transpose = _ops.transpose
12591261
unique_with_counts = _tf_unique_with_counts
@@ -1379,7 +1381,10 @@ def __init__(self, seed: SeedType):
13791381
self._key: Optional["_jax.Array"] = None
13801382

13811383
if seed is None:
1382-
return
1384+
# Automatically generate a seed if none is provided, allowing JAX
1385+
# to function similarly to other backends where None is acceptable.
1386+
seed = np.random.randint(_MAX_INT32)
1387+
self._int_seed = seed
13831388

13841389
if (
13851390
isinstance(seed, jax.Array) # pylint: disable=undefined-variable

meridian/backend/backend_test.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,15 +1615,29 @@ def test_correct_class_is_exposed(self, backend_name):
16151615
self.assertIs(backend.RNGHandler, backend._TFRNGHandler)
16161616
# pylint: enable=protected-access
16171617

1618-
@parameterized.named_parameters(("tensorflow", _TF), ("jax", _JAX))
1619-
def test_initialization_with_none_seed_is_noop(self, backend_name):
1620-
"""Verifies that a None seed creates a handler that returns None."""
1618+
@parameterized.named_parameters(
1619+
dict(
1620+
testcase_name="tensorflow",
1621+
backend_name=_TF,
1622+
assert_fn_name="assertIsNone",
1623+
),
1624+
dict(
1625+
testcase_name="jax",
1626+
backend_name=_JAX,
1627+
assert_fn_name="assertIsNotNone",
1628+
),
1629+
)
1630+
def test_initialization_with_none_seed_is_noop(
1631+
self, backend_name, assert_fn_name
1632+
):
1633+
"""Verifies behavior when initialized with None."""
16211634
self._set_backend_for_test(backend_name)
16221635
handler = backend.RNGHandler(None)
1636+
assertion = getattr(self, assert_fn_name)
16231637

16241638
self.assertIsNone(handler._seed_input)
1625-
self.assertIsNone(handler.get_next_seed())
1626-
self.assertIsNone(handler.get_kernel_seed())
1639+
assertion(handler.get_next_seed())
1640+
assertion(handler.get_kernel_seed())
16271641

16281642
@parameterized.named_parameters(("tensorflow", _TF), ("jax", _JAX))
16291643
def test_initialization_with_integer_seed(self, backend_name):
@@ -1770,17 +1784,29 @@ def test_get_next_seed_is_reproducible(self, backend_name):
17701784
else:
17711785
test_utils.assert_allequal(s1, s2)
17721786

1773-
@parameterized.named_parameters(("tensorflow", _TF), ("jax", _JAX))
1774-
def test_advance_handler_with_none_seed(self, backend_name):
1775-
"""Tests that advancing a no-op handler produces another no-op handler."""
1787+
@parameterized.named_parameters(
1788+
dict(
1789+
testcase_name="tensorflow",
1790+
backend_name=_TF,
1791+
assert_fn_name="assertIsNone",
1792+
),
1793+
dict(
1794+
testcase_name="jax",
1795+
backend_name=_JAX,
1796+
assert_fn_name="assertIsNotNone",
1797+
),
1798+
)
1799+
def test_advance_handler_with_none_seed(self, backend_name, assert_fn_name):
1800+
"""Tests advancing a handler initialized with None."""
17761801
self._set_backend_for_test(backend_name)
17771802
handler = backend.RNGHandler(None)
17781803
new_handler = handler.advance_handler()
1804+
assertion = getattr(self, assert_fn_name)
17791805

17801806
self.assertIsNot(handler, new_handler)
1781-
self.assertIsNone(new_handler._seed_input)
1782-
self.assertIsNone(handler.get_next_seed())
1783-
self.assertIsNone(new_handler.get_kernel_seed())
1807+
assertion(new_handler._seed_input)
1808+
assertion(handler.get_next_seed())
1809+
assertion(new_handler.get_kernel_seed())
17841810

17851811
@parameterized.named_parameters(("tensorflow", _TF), ("jax", _JAX))
17861812
def test_advance_handler_provides_independent_handlers(self, backend_name):

meridian/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@
474474
GAMMA_GC_DEV,
475475
GAMMA_GN_DEV,
476476
TAU_G_EXCL_BASELINE, # Used to derive TAU_G.
477+
'y',
477478
)
478479
IGNORED_PRIORS_MEDIA = immutabledict.immutabledict({
479480
TREATMENT_PRIOR_TYPE_ROI: (

0 commit comments

Comments
 (0)