Skip to content

Commit e88bf3d

Browse files
authored
fix(gh-2181): ensure consistent behavior of Poisson.log_prob by casting rate to floating-point type (#2182)
* fix(gh-2181): ensure consistent behavior of `Poisson.log_prob` by casting rate to floating-point type * test: unit test that covers the fix * fix: remove float values from test cases
1 parent 517f3c1 commit e88bf3d

2 files changed

Lines changed: 45 additions & 9 deletions

File tree

numpyro/distributions/discrete.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -783,32 +783,33 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
783783

784784
@validate_sample
785785
def log_prob(self, value: ArrayLike) -> ArrayLike:
786-
if self._validate_args:
787-
self._validate_sample(value)
786+
# Using an integer vs. floating-point `rate` leads to differing results.
787+
# To ensure consistent behavior, `rate` is explicitly cast to a floating-point type.
788+
# See: https://github.com/pyro-ppl/numpyro/issues/2181
789+
ftype = jnp.result_type(float)
790+
rate = jnp.astype(self.rate, ftype)
791+
788792
if (
789793
self.is_sparse
790794
and not isinstance(value, jax.core.Tracer)
791795
and jnp.size(value) > 1
792796
):
793797
shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value))
794-
rate = jnp.broadcast_to(self.rate, shape).reshape(-1)
798+
rate = jnp.broadcast_to(rate, shape).reshape(-1)
795799
nonzero = np.broadcast_to(jax.device_get(value) > 0, shape).reshape(-1)
796800
value = jnp.broadcast_to(value, shape).reshape(-1)
797801
sparse_value = value[nonzero]
798802
sparse_rate = rate[nonzero]
799803
return (
800-
jnp.asarray(-rate, dtype=jnp.result_type(float))
804+
jnp.asarray(-rate, dtype=ftype)
801805
.at[nonzero]
802806
.add(
803807
jnp.log(sparse_rate) * sparse_value - gammaln(sparse_value + 1),
804808
)
805809
.reshape(shape)
806810
)
807-
return (
808-
xlogy(jnp.astype(value, jnp.result_type(self.rate)), self.rate)
809-
- gammaln(value + 1)
810-
- self.rate
811-
)
811+
_value = jnp.astype(value, ftype)
812+
return xlogy(_value, rate) - gammaln(_value + 1.0) - rate
812813

813814
@property
814815
def mean(self) -> ArrayLike:

test/test_distributions.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4810,3 +4810,38 @@ def test_uniform_log_prob_outside_support():
48104810
match="Out-of-support values provided to log prob method. The value argument should be within the support.",
48114811
):
48124812
d.log_prob(-0.5)
4813+
4814+
4815+
@pytest.mark.parametrize("rate", [0, 1, 2, 5, 10, 1e-6, 1e2])
4816+
@pytest.mark.parametrize("value", [0, 1, 2, 5, 10])
4817+
def test_poisson_dtype_consistency(rate, value):
4818+
"""
4819+
Ensure that ``Poisson.log_prob`` is invariant to dtype differences in both
4820+
the rate parameter and the observed value.
4821+
4822+
This test checks that using integer vs. floating-point representations for:
4823+
- the rate (e.g., ``2`` vs ``2.0``), and
4824+
- the observed value (e.g., ``2`` vs ``2.0``),
4825+
4826+
yields identical log-probabilities across all combinations. This includes
4827+
edge cases such as zero rate, very small rates, large rates, and values
4828+
near integer boundaries.
4829+
4830+
The test guards against dtype-dependent behavior, where numerically
4831+
equivalent inputs produce inconsistent results due to type casting or
4832+
implementation details.
4833+
4834+
See: https://github.com/pyro-ppl/numpyro/issues/2181
4835+
"""
4836+
rates = [rate, float(rate)]
4837+
4838+
results = []
4839+
for r in rates:
4840+
d = dist.Poisson(r, validate_args=True)
4841+
results.append(d.log_prob(value))
4842+
4843+
ref = results[0]
4844+
for res in results[1:]:
4845+
assert jnp.allclose(res, ref), (
4846+
f"Inconsistent results for rate={rate}, value={value}: {results}"
4847+
)

0 commit comments

Comments
 (0)