
Handle zero-mean Negative Binomial distributions #2193

@dylanhmorris

Description


The negative binomial distribution is well-defined with mean $\mu = 0$. Equivalently, it is well-defined with $p = 0$ in the probs/total_count parametrization used by NegativeBinomialProbs. In practice, NegativeBinomial2 fails validation with mean 0. More concerningly, NegativeBinomialProbs passes validation with probs = 0 but then behaves incorrectly, yielding nan log prob values. Forward sampling and evaluation of the distribution mean work correctly.
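For reference, the mass function is written below in both parametrizations. Here $r$ is the concentration (total_count) and $p$ is the probs parameter. The link $p = \mu / (r + \mu)$ follows from the mean $r p / (1 - p)$, which is implied by the rate $1/p - 1$ that NegativeBinomialProbs uses. Setting $\mu = 0$ (equivalently $p = 0$) and using the usual convention $0^0 = 1$ leaves a point mass at zero:

$$
P(K = k) = \frac{\Gamma(r + k)}{\Gamma(r)\, k!} \left(\frac{r}{r + \mu}\right)^{r} \left(\frac{\mu}{r + \mu}\right)^{k} = \frac{\Gamma(r + k)}{\Gamma(r)\, k!} (1 - p)^{r} p^{k}, \qquad p = \frac{\mu}{r + \mu}
$$

$$
\mu = 0 \;\; (p = 0): \qquad P(K = 0) = 1, \qquad P(K = k) = 0 \;\text{ for } k \ge 1
$$

so the correct log probabilities are $\log P(K = 0) = 0$ and $\log P(K = k) = -\infty$ for $k \ge 1$.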

This happens because NegativeBinomialProbs(probs=jnp.array(0)) and NegativeBinomial2(mean=jnp.array(0)) both end up constructing a GammaPoisson with an infinite rate, via these conversions (first NegativeBinomialProbs, then NegativeBinomial2):

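```python
# NegativeBinomialProbs: probs = 0 gives rate = 1 / 0 - 1 = inf
rate = 1.0 / probs - 1.0

# NegativeBinomial2: mean = 0 gives rate = concentration / 0 = inf
rate = concentration / mean
```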
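The following sketch shows why an infinite rate produces nan. It evaluates the standard Gamma-Poisson log-mass with rate $\beta$, $\log\Gamma(r+k) - \log\Gamma(r) - \log k! + r\log\beta - (r+k)\log(1+\beta)$, term by term. It assumes the library evaluates the two rate terms separately in roughly this form, which has not been checked line by line against the numpyro source. The helper name gamma_poisson_log_prob is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def gamma_poisson_log_prob(value, concentration, rate):
    """Hypothetical helper: Gamma-Poisson log-mass written in terms of the rate."""
    return (
        gammaln(concentration + value)
        - gammaln(concentration)
        - gammaln(value + 1.0)
        + concentration * jnp.log(rate)  # r * log(inf) = inf
        - (concentration + value) * jnp.log1p(rate)  # (r + k) * log1p(inf) = inf
    )


probs = jnp.array(0.0)
rate = 1.0 / probs - 1.0  # the probs -> rate conversion gives inf
print(rate)
print(gamma_poisson_log_prob(0.0, 1.0, rate))  # inf - inf is nan
print(gamma_poisson_log_prob(1.0, 1.0, rate))  # nan as well
```

Each rate term diverges on its own, even though their difference has a well-defined limit. This is why the fix probably needs a different formula rather than a special case on the rate.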

Action

I think this can and should be patched. I am happy to contribute a PR, but would appreciate some discussion of the approach. In particular, do we want inf-rate GammaPoisson instances to be acceptable and well-defined? My instinct is yes, but I can imagine there might be sharp bits.
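One possible direction is sketched here only to make the discussion concrete. The idea is to evaluate the log-mass directly in the probs parametrization, $\log\Gamma(r+k) - \log\Gamma(r) - \log k! + r\log(1-p) + k\log p$. The sketch uses jax.scipy.special.xlogy so that $k \log p$ is taken as 0 when $k = 0$. This is not numpyro's current implementation. The function name nb_probs_log_prob is hypothetical, and probs = 1 is not considered.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def nb_probs_log_prob(value, total_count, probs):
    """Hypothetical sketch: negative binomial log-mass in the (total_count, probs) form.

    The mean is total_count * probs / (1 - probs), matching rate = 1 / probs - 1.
    """
    value = jnp.asarray(value, dtype=jnp.float32)
    return (
        gammaln(total_count + value)
        - gammaln(total_count)
        - gammaln(value + 1.0)
        + total_count * jnp.log1p(-probs)  # r * log(1 - p); exactly 0 when p = 0
        + xlogy(value, probs)  # k * log(p), defined as 0 when k = 0
    )


# probs = 0 is the zero-mean case, so all mass should sit at 0.
print(nb_probs_log_prob(0, 1.0, jnp.array(0.0)))  # ~0
print(nb_probs_log_prob(1, 1.0, jnp.array(0.0)))  # -inf
```

This form never builds the infinite rate. Whether NegativeBinomial2 should route through the same path, via probs = mean / (mean + concentration), is part of the design question above.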

Details and reprex

For NegativeBinomial2, the mean has constraint constraints.positive, which excludes 0, so construction with a zero mean fails when validate_args=True. Note: if the mean is a Python int or float rather than a JAX or NumPy array, construction fails even before validation, with a divide-by-zero error.

```python
import jax.numpy as jnp
import numpyro.distributions as dist

dist.NegativeBinomial2(mean=jnp.array(0), concentration=1, validate_args=True)  # fails validation
# ValueError: NegativeBinomial2 distribution got invalid mean parameter.
dist.NegativeBinomial2(mean=0, concentration=1, validate_args=True)  # divide-by-zero error
# ZeroDivisionError: division by zero
```

For NegativeBinomialProbs, probs has constraint constraints.unit_interval, which includes 0. Construction and log prob evaluation therefore succeed with probs=jnp.array(0) regardless of the value of validate_args. (As in the NegativeBinomial2 case, construction fails with a divide-by-zero error if probs is a Python int or float.) But the log prob values are nan:

```python
dist.NegativeBinomialProbs(total_count=1, probs=0, validate_args=True)  # divide-by-zero error
# ZeroDivisionError: float division by zero
my_neg_bin = dist.NegativeBinomialProbs(total_count=1, probs=jnp.array(0), validate_args=True)  # construction succeeds
my_neg_bin.log_prob(0)  # nan log prob; should be 0
# Array(nan, dtype=float32, weak_type=True)
my_neg_bin.log_prob(1)  # nan log prob; should be -inf
# Array(nan, dtype=float32, weak_type=True)
```

Forward sampling and evaluation of the mean behave correctly; for these operations, the underlying GammaPoisson handles the inf rate gracefully:

```python
import numpyro

my_neg_bin.mean  # correct mean
# Array(0., dtype=float32, weak_type=True)

# correct forward sampling
with numpyro.handlers.seed(rng_seed=5):
    samples = numpyro.sample('neg bin samples', my_neg_bin.expand((100,)))

all(samples == 0)
# True
```
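A Gamma-Poisson draw can be generated in two steps: draw from a Gamma distribution, divide by the rate, and then draw from a Poisson distribution. Dividing by an infinite rate gives a Poisson rate of exactly 0, and Poisson(0) always returns 0. Likewise, the mean, concentration / rate, is 1 / inf = 0. The sketch below reproduces this two-step construction with plain jax.random calls. It assumes, without verifying, that numpyro's GammaPoisson.sample follows the same route.

```python
import jax
import jax.numpy as jnp

key_gamma, key_poisson = jax.random.split(jax.random.PRNGKey(5))
concentration, rate = 1.0, jnp.inf

# Step 1: Gamma(concentration, rate) draws; scaling a unit-rate draw by 1 / inf gives 0.
poisson_rates = jax.random.gamma(key_gamma, concentration, shape=(100,)) / rate
# Step 2: Poisson draws with rate 0 are all 0.
samples = jax.random.poisson(key_poisson, poisson_rates)

print(bool(jnp.all(poisson_rates == 0.0)), bool(jnp.all(samples == 0)))
print(concentration / rate)  # the mean, concentration / rate
```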

Behavior is similar if we construct a zero-mean NegativeBinomial2 with validate_args=False. This makes sense; either way, we end up with a GammaPoisson with an inf rate.

```python
my_neg_bin2 = dist.NegativeBinomial2(mean=jnp.array(0), concentration=1, validate_args=False)
my_neg_bin2.log_prob(0)  # nan log prob; should be 0
# Array(nan, dtype=float32, weak_type=True)
my_neg_bin2.log_prob(1)  # nan log prob; should be -inf
# Array(nan, dtype=float32, weak_type=True)
my_neg_bin2.mean  # correct mean
# Array(0., dtype=float32, weak_type=True)
with numpyro.handlers.seed(rng_seed=5):
    samples = numpyro.sample('neg bin samples', my_neg_bin2.expand((100,)))

all(samples == 0)
# True
```

Labels: bug