Description
The negative binomial distribution is well-defined with mean $\mu = 0$: it reduces to a point mass at zero. Equivalently, it is well-defined with $p = 0$ in the `probs`/`total_count` parametrization used by `NegativeBinomialProbs`. In practice, `NegativeBinomial2` fails validation with mean 0. More concerningly, `NegativeBinomialProbs` passes validation with `probs = 0` but then behaves incorrectly, yielding `nan` log prob values. Forward sampling and evaluation of the distribution mean work correctly.
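This stems from the fact that `NegativeBinomialProbs(probs=jnp.array(0))` and `NegativeBinomial2(mean=jnp.array(0))` both end up constructing a `GammaPoisson` with an `inf` rate (see `numpyro/numpyro/distributions/conjugate.py`, lines 483 and 554 at commit a1bcf5e). In the `NegativeBinomial2` case, for example:

```python
rate = concentration / mean
```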
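A minimal sketch of why an infinite rate turns into `nan`, assuming the log pmf is evaluated in the standard Gamma-Poisson form. The helper `gamma_poisson_log_pmf` is hypothetical (not numpyro's API): with rate $\beta = \infty$, the $r \log \beta$ and $(r + k)\log(1 + \beta)$ terms are both infinite, and their difference is `nan`.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def gamma_poisson_log_pmf(value, concentration, rate):
    """Hypothetical helper: standard Gamma-Poisson (negative binomial) log pmf.

    log P(k) = log Gamma(k + r) - log Gamma(r) - log k!
               + r * log(rate) - (r + k) * log1p(rate)
    """
    return (
        gammaln(value + concentration)
        - gammaln(concentration)
        - gammaln(value + 1.0)
        + concentration * jnp.log(rate)
        - (concentration + value) * jnp.log1p(rate)
    )


concentration = 1.0
mean = jnp.array(0.0)
rate = concentration / mean  # JAX array division by zero gives inf, no error

print(rate)                                             # -> inf
print(gamma_poisson_log_pmf(0.0, concentration, rate))  # -> nan (inf - inf)
print(gamma_poisson_log_pmf(0.0, concentration, 1.0))   # finite: log(1/2)
```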
Action
I think this can and should be patched. I am happy to contribute a PR, but would appreciate some discussion on approach. In particular, do we want inf-rate GammaPoisson instances to be acceptable and well-defined? My instinct is yes but I can imagine there might be sharp bits.
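One possible approach, sketched only to seed the discussion: treat an infinite rate as a point mass at zero inside the log pmf. The helper `safe_gamma_poisson_log_pmf` is hypothetical (not an existing numpyro function) and assumes the standard Gamma-Poisson log-pmf form; an actual patch would presumably live in `GammaPoisson.log_prob`. One sharp bit it illustrates: `jnp.where` evaluates both branches, so the generic branch is given a finite placeholder rate to keep `nan`s out of gradients.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def safe_gamma_poisson_log_pmf(value, concentration, rate):
    """Hypothetical patched Gamma-Poisson log pmf that accepts rate == inf."""
    value = jnp.asarray(value)
    rate = jnp.asarray(rate)
    is_point_mass = jnp.isinf(rate)
    # jnp.where evaluates both branches, so feed the generic branch a finite
    # placeholder rate wherever rate is inf; otherwise inf - inf -> nan can
    # leak into gradients even when that branch is not selected.
    safe_rate = jnp.where(is_point_mass, 1.0, rate)
    generic = (
        gammaln(value + concentration)
        - gammaln(concentration)
        - gammaln(value + 1.0)
        + concentration * jnp.log(safe_rate)
        - (concentration + value) * jnp.log1p(safe_rate)
    )
    # Infinite rate => all mass at zero: log P(0) = 0, log P(k >= 1) = -inf.
    point_mass = jnp.where(value == 0, 0.0, -jnp.inf)
    return jnp.where(is_point_mass, point_mass, generic)


print(safe_gamma_poisson_log_pmf(0, 1.0, jnp.inf))  # 0.0
print(safe_gamma_poisson_log_pmf(1, 1.0, jnp.inf))  # -inf
print(safe_gamma_poisson_log_pmf(0, 1.0, 1.0))      # log(1/2), unchanged
```

Whether this check belongs in `GammaPoisson` itself (accepting `inf` rates generally) or only in the negative binomial wrappers is the open question above.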
Details and reprex
For `NegativeBinomial2`, `mean` has the constraint `constraints.positive`, so construction fails with `validate_args=True`. Note: if the mean is a Python `int` or `float` rather than a JAX or NumPy array, construction fails with a `ZeroDivisionError` before validation is even reached.
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

dist.NegativeBinomial2(mean=jnp.array(0), concentration=1, validate_args=True)  # fails validation
# ValueError: NegativeBinomial2 distribution got invalid mean parameter.

dist.NegativeBinomial2(mean=0, concentration=1, validate_args=True)  # divide-by-zero error
# ZeroDivisionError: division by zero
```
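The difference between the two calls comes down to how division by zero behaves: Python scalars raise, while JAX arrays follow IEEE 754 and return `inf`. A minimal illustration using the `rate = concentration / mean` expression quoted earlier; the helper `nb2_rate` is hypothetical.

```python
import jax.numpy as jnp


def nb2_rate(mean, concentration):
    # Hypothetical helper mirroring `rate = concentration / mean`.
    return concentration / mean


try:
    nb2_rate(0, 1)  # plain Python ints
    python_raised = False
except ZeroDivisionError:
    python_raised = True  # ZeroDivisionError: division by zero

array_rate = nb2_rate(jnp.array(0), 1)  # JAX array: no error, rate is inf

print(python_raised)  # True
print(array_rate)     # inf
```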
For `NegativeBinomialProbs`, `probs` has the constraint `constraints.unit_interval`, which includes 0. Construction and evaluation of the log prob therefore succeed with `probs=jnp.array(0)`, even with `validate_args=True`. (As in the `NegativeBinomial2` case, construction fails with a `ZeroDivisionError` if `probs` is a Python `int` or `float`.) But the log prob values are `nan`:
```python
dist.NegativeBinomialProbs(total_count=1, probs=0, validate_args=True)  # divide-by-zero error
# ZeroDivisionError: float division by zero

my_neg_bin = dist.NegativeBinomialProbs(total_count=1, probs=jnp.array(0), validate_args=True)  # construction succeeds
my_neg_bin.log_prob(0)  # nan log prob; should be 0
# Array(nan, dtype=float32, weak_type=True)
my_neg_bin.log_prob(1)  # nan log prob; should be -inf
# Array(nan, dtype=float32, weak_type=True)
```
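For reference, the expected values in the comments above follow from the pmf. A short derivation, writing the underlying `GammaPoisson` with concentration $r$ and rate $\beta$ and assuming the standard Gamma-Poisson mixture form, in which the mean is $\mu = r/\beta$ (for `NegativeBinomial2`, $\beta = r/\mu$, as in `rate = concentration / mean`):

$$
\Pr(K = k) = \frac{\Gamma(k + r)}{k!\,\Gamma(r)} \left(\frac{\beta}{1+\beta}\right)^{r} \left(\frac{1}{1+\beta}\right)^{k}
= \frac{\Gamma(k + r)}{k!\,\Gamma(r)}\,(1-p)^{r}\,p^{k},
\qquad p = \frac{1}{1+\beta} = \frac{\mu}{\mu + r}.
$$

With $\mu = 0$ (so $p = 0$ and $\beta = \infty$), and using $0^0 = 1$:

$$
\Pr(K = 0) = 1, \qquad \Pr(K = k) = 0 \quad (k \ge 1),
$$

so $\log \Pr(K = 0) = 0$ and $\log \Pr(K = k) = -\infty$ for $k \ge 1$.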
Forward sampling and evaluation of the mean behave correctly; in these cases the underlying `GammaPoisson` handles the `inf` rate gracefully:
```python
my_neg_bin.mean  # correct mean
# Array(0., dtype=float32, weak_type=True)

# correct forward sampling
with numpyro.handlers.seed(rng_seed=5):
    samples = numpyro.sample('neg bin samples', my_neg_bin.expand((100,)))
all(samples == 0)
# True
```
Behavior is similar if we construct a zero-mean `NegativeBinomial2` with `validate_args=False`. This makes sense: either way, we end up with a `GammaPoisson` with an `inf` rate.

```python
my_neg_bin2 = dist.NegativeBinomial2(mean=jnp.array(0), concentration=1, validate_args=False)
my_neg_bin2.log_prob(0)  # nan log prob; should be 0
# Array(nan, dtype=float32, weak_type=True)
my_neg_bin2.log_prob(1)  # nan log prob; should be -inf
# Array(nan, dtype=float32, weak_type=True)
my_neg_bin2.mean  # correct mean
# Array(0., dtype=float32, weak_type=True)

with numpyro.handlers.seed(rng_seed=5):
    samples = numpyro.sample('neg bin samples', my_neg_bin2.expand((100,)))
all(samples == 0)
# True
```