# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Poisson distribution."""

from typing import Tuple, Union

import chex
from distrax._src.distributions import distribution
from distrax._src.utils import conversion
from distrax._src.utils import math
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

Array = chex.Array
Numeric = chex.Numeric
PRNGKey = chex.PRNGKey
EventT = distribution.EventT


class Poisson(distribution.Distribution):
  """Poisson distribution with a rate parameter.

  The distribution is parameterized by a (batched) `rate`, which is also
  its mean and variance.
  """

  equiv_tfp_cls = tfd.Poisson

  def __init__(self, rate: Numeric):
    """Initializes a Poisson distribution.

    Args:
      rate: Rate of the distribution.
    """
    super().__init__()
    self._rate = conversion.as_float_array(rate)
    self._batch_shape = self._rate.shape

  @property
  def event_shape(self) -> Tuple[int, ...]:
    """Shape of event of distribution samples."""
    return ()

  @property
  def batch_shape(self) -> Tuple[int, ...]:
    """Shape of batch of distribution samples."""
    return self._batch_shape

  @property
  def rate(self) -> Array:
    """Rate parameter (equal to the mean), broadcast to the batch shape."""
    return jnp.broadcast_to(self._rate, self.batch_shape)

  def _sample_n(self, key: PRNGKey, n: int) -> Array:
    """See `Distribution._sample_n`."""
    sample_shape = (n,) + self.batch_shape
    return jax.random.poisson(key, self.rate, sample_shape)

  def log_prob(self, value: EventT) -> Array:
    """See `Distribution.log_prob`."""
    rate = self.rate
    # log p(k) = k * log(rate) - log(k!) - rate, with log(k!) = gammaln(k + 1).
    log_factorial = jax.scipy.special.gammaln(value + 1)
    return value * jnp.log(rate) - log_factorial - rate

  def cdf(self, value: EventT) -> Array:
    """See `Distribution.cdf`."""
    # P(X <= value) = Q(floor(value) + 1, rate), where Q is the regularized
    # upper incomplete gamma function.
    shape_param = jnp.floor(value) + 1
    return jax.scipy.special.gammaincc(shape_param, self.rate)

  def log_cdf(self, value: EventT) -> Array:
    """See `Distribution.log_cdf`."""
    return jnp.log(self.cdf(value))

  def mean(self) -> Array:
    """Calculates the mean."""
    return self.rate

  def stddev(self) -> Array:
    """Calculates the standard deviation."""
    return jnp.sqrt(self.rate)

  def variance(self) -> Array:
    """Calculates the variance."""
    return self.rate

  def mode(self) -> Array:
    """Calculates the mode."""
    return jnp.ceil(self.rate) - 1

  def __getitem__(self, index) -> 'Poisson':
    """See `Distribution.__getitem__`."""
    index = distribution.to_batch_shape_index(self.batch_shape, index)
    return Poisson(rate=self.rate[index])


def _kl_divergence_poisson_poisson(
    dist1: Union[Poisson, tfd.Poisson],
    dist2: Union[Poisson, tfd.Poisson],
    *unused_args,
    **unused_kwargs,
) -> Array:
  """Batched KL divergence KL(dist1 || dist2) between two poisson distributions.

  Args:
    dist1: A poisson distribution.
    dist2: A poisson distribution.

  Returns:
    Batchwise `KL(dist1 || dist2)`.
  """
  rate1 = dist1.rate
  rate2 = dist2.rate
  # KL = rate1 * log(rate1 / rate2) - (rate1 - rate2). The first term goes
  # through `multiply_no_nan` so that (per its name) a zero rate1 does not
  # turn 0 * -inf into NaN.
  log_rate_ratio = jnp.log(rate1) - jnp.log(rate2)
  return math.multiply_no_nan(rate1, log_rate_ratio) - (rate1 - rate2)


# Register the KL functions with TFP.
tfd.RegisterKL(Poisson, Poisson)(_kl_divergence_poisson_poisson)
tfd.RegisterKL(Poisson, Poisson.equiv_tfp_cls)(_kl_divergence_poisson_poisson)
tfd.RegisterKL(Poisson.equiv_tfp_cls, Poisson)(_kl_divergence_poisson_poisson)
# ==============================================================================
"""Tests for `poisson.py`."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
from distrax._src.distributions import poisson
from distrax._src.utils import equivalence
import jax.numpy as jnp
import numpy as np


class PoissonTest(equivalence.EquivalenceTest):
  """Checks `poisson.Poisson` against its TFP equivalent (`tfd.Poisson`)."""

  def setUp(self):
    super().setUp()
    self._init_distr_cls(poisson.Poisson)

  @parameterized.named_parameters(
      ('1d poisson', (1,)),
      ('2d poisson', (np.ones(2),)),
      ('rank 2 poisson', (np.zeros((3, 2)),)),
  )
  def test_event_shape(self, distr_params):
    super()._test_event_shape(distr_params, dict())

  @chex.all_variants
  @parameterized.named_parameters(
      ('1d poisson, no shape', (1,), ()),
      ('1d poisson, int shape', (1,), 1),
      ('1d poisson, 1-tuple shape', (1,), (1,)),
      ('1d poisson, 2-tuple shape', (1,), (2, 2)),
      ('2d poisson, no shape', (np.ones(2),), ()),
      ('2d poisson, int shape', ([1, 1],), 1),
      ('2d poisson, 1-tuple shape', (np.ones(2),), (1,)),
      ('2d poisson, 2-tuple shape', ([1, 1],), (2, 2)),
      (
          'rank 2 poisson, 2-tuple shape',
          (np.ones((3, 2)),),
          (2, 2),
      ),
  )
  def test_sample_shape(self, distr_params, sample_shape):
    distr_params = (np.asarray(distr_params[0], dtype=np.float32),)
    super()._test_sample_shape(distr_params, dict(), sample_shape)

  @chex.all_variants
  @parameterized.named_parameters(
      ('float32', jnp.float32), ('float64', jnp.float64)
  )
  def test_sample_dtype(self, dtype):
    # Fix: `Poisson` is parameterized by `rate`, not `loc`/`scale`. The
    # previous call was copied from a location-scale distribution's test and
    # raised a TypeError on construction.
    dist = self.distrax_cls(rate=jnp.ones((), dtype))
    samples = self.variant(dist.sample)(seed=self.key)
    self.assertEqual(samples.dtype, dist.dtype)
    chex.assert_type(samples, dtype)

  @chex.all_variants
  @parameterized.named_parameters(
      ('1d poisson, no shape', (1,), ()),
      ('1d poisson, int shape', (1,), 1),
      ('1d poisson, 1-tuple shape', (1,), (1,)),
      ('1d poisson, 2-tuple shape', (1,), (2, 2)),
      ('2d poisson, no shape', (np.ones(2),), ()),
      ('2d poisson, int shape', ([1, 1],), 1),
      ('2d poisson, 1-tuple shape', (np.ones(2),), (1,)),
      ('2d poisson, 2-tuple shape', ([1, 1],), (2, 2)),
      (
          'rank 2 poisson, 2-tuple shape',
          (np.ones((3, 2)),),
          (2, 2),
      ),
  )
  def test_sample_and_log_prob(self, distr_params, sample_shape):
    distr_params = (np.asarray(distr_params[0], dtype=np.float32),)
    super()._test_sample_and_log_prob(
        dist_args=distr_params,
        dist_kwargs=dict(),
        sample_shape=sample_shape,
        assertion_fn=self.assertion_fn(rtol=2e-2),
    )

  @chex.all_variants
  @parameterized.named_parameters(
      ('1d dist, 1d value', (1,), 1),
      ('1d dist, 2d value', (0.5,), np.array([1, 2])),
      ('1d dist, 2d value as list', (0.5,), [1, 2]),
      ('2d dist, 1d value', (0.5 + np.zeros(2),), 1),
      ('2d dist, 2d value', ([0.1, 0.5],), np.array([1, 2])),
      ('1d dist, 1d value, edge case', (1,), 200),
  )
  def test_log_prob(self, distr_params, value):
    distr_params = (np.asarray(distr_params[0], dtype=np.float32),)
    value = np.asarray(value, dtype=np.float32)
    super()._test_attribute(
        attribute_string='log_prob',
        dist_args=distr_params,
        call_args=(value,),
        assertion_fn=self.assertion_fn(rtol=2e-2),
    )

  @chex.all_variants
  @parameterized.named_parameters(
      ('1d dist, 1d value', (1,), 1),
      ('1d dist, 2d value', (0.5,), np.array([1, 2])),
      ('1d dist, 2d value as list', (0.5,), [1, 2]),
      ('2d dist, 1d value', (0.5 + np.zeros(2),), 1),
      ('2d dist, 2d value', ([0.1, 0.5],), np.array([1, 2])),
      ('1d dist, 1d value, edge case', (1,), 200),
  )
  def test_prob(self, distr_params, value):
    distr_params = (np.asarray(distr_params[0], dtype=np.float32),)
    value = np.asarray(value, dtype=np.float32)
    super()._test_attribute(
        attribute_string='prob',
        dist_args=distr_params,
        call_args=(value,),
        assertion_fn=self.assertion_fn(rtol=2e-2),
    )

  @chex.all_variants
  @parameterized.named_parameters(
      ('1d dist, 1d value', (1,), 1),
      ('1d dist, 2d value', (0.5,), np.array([1, 2])),
      ('1d dist, 2d value as list', (0.5,), [1, 2]),
      ('2d dist, 1d value', (0.5 + np.zeros(2),), 1),
      ('2d dist, 2d value', ([0.1, 0.5],), np.array([1, 2])),
      ('1d dist, 1d value, edge case', (1,), 200),
  )
  def test_cdf(self, distr_params, value):
    distr_params = (np.asarray(distr_params[0], dtype=np.float32),)
    value = np.asarray(value, dtype=np.float32)
    super()._test_attribute(
        attribute_string='cdf',
        dist_args=distr_params,
        call_args=(value,),
        assertion_fn=self.assertion_fn(rtol=2e-2),
    )

  @chex.all_variants
  @parameterized.named_parameters(
      ('1d dist, 1d value', (1,), 1),
      ('1d dist, 2d value', (0.5,), np.array([1, 2])),
      ('1d dist, 2d value as list', (0.5,), [1, 2]),
      ('2d dist, 1d value', (0.5 + np.zeros(2),), 1),
      ('2d dist, 2d value', ([0.1, 0.5],), np.array([1, 2])),
      ('1d dist, 1d value, edge case', (1,), 200),
  )
  def test_log_cdf(self, distr_params, value):
    distr_params = (np.asarray(distr_params[0], dtype=np.float32),)
    value = np.asarray(value, dtype=np.float32)
    super()._test_attribute(
        attribute_string='log_cdf',
        dist_args=distr_params,
        call_args=(value,),
        assertion_fn=self.assertion_fn(rtol=2e-2),
    )

  @chex.all_variants
  @parameterized.named_parameters(
      ('1d dist, 1d value', (1,), 1),
      ('1d dist, 2d value', (0.5,), np.array([1, 2])),
      ('1d dist, 2d value as list', (0.5,), [1, 2]),
      ('2d dist, 1d value', (0.5 + np.zeros(2),), 1),
      ('2d dist, 2d value', ([0.1, 0.5],), np.array([1, 2])),
      ('1d dist, 1d value, edge case', (1,), 200),
  )
  def test_log_survival_function(self, distr_params, value):
    distr_params = (np.asarray(distr_params[0], dtype=np.float32),)
    value = np.asarray(value, dtype=np.float32)
    super()._test_attribute(
        attribute_string='log_survival_function',
        dist_args=distr_params,
        call_args=(value,),
        assertion_fn=self.assertion_fn(rtol=2e-2),
    )

  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters(
      ('mean', ([0.1, 1.0, 0.5],), 'mean'),
      ('variance', ([0.1, 1.0, 0.5],), 'variance'),
      ('stddev', ([0.1, 1.0, 0.5],), 'stddev'),
      ('mode', ([0.1, 1.0, 0.5],), 'mode'),
  )
  def test_method(self, distr_params, function_string):
    distr_params = (np.asarray(distr_params[0], dtype=np.float32),)
    super()._test_attribute(
        attribute_string=function_string,
        dist_args=distr_params,
        assertion_fn=self.assertion_fn(rtol=2e-2),
    )

  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters(
      ('kl distrax_to_distrax', 'kl_divergence', 'distrax_to_distrax'),
      ('kl distrax_to_tfp', 'kl_divergence', 'distrax_to_tfp'),
      ('kl tfp_to_distrax', 'kl_divergence', 'tfp_to_distrax'),
      ('cross-ent distrax_to_distrax', 'cross_entropy', 'distrax_to_distrax'),
      ('cross-ent distrax_to_tfp', 'cross_entropy', 'distrax_to_tfp'),
      ('cross-ent tfp_to_distrax', 'cross_entropy', 'tfp_to_distrax'),
  )
  def test_with_two_distributions(self, function_string, mode_string):
    rng = np.random.default_rng(42)
    super()._test_with_two_distributions(
        attribute_string=function_string,
        mode_string=mode_string,
        dist1_kwargs={
            'rate': jnp.exp(rng.normal(size=(4, 1, 2))),
        },
        dist2_kwargs={
            'rate': jnp.exp(rng.normal(size=(3, 2))),
        },
        assertion_fn=self.assertion_fn(rtol=2e-2),
    )

  def test_jitable(self):
    super()._test_jittable((1.0,))

  @parameterized.named_parameters(
      ('single element', 2),
      ('range', slice(-1)),
      ('range_2', (slice(None), slice(-1))),
      ('ellipsis', (Ellipsis, -1)),
  )
  def test_slice(self, slice_):
    rng = np.random.default_rng(42)
    rate = jnp.exp(jnp.array(rng.normal(size=(3, 4, 5))))
    dist = self.distrax_cls(rate=rate)
    # The mean of a Poisson equals its rate, so slicing the distribution
    # must agree with slicing the rate array directly.
    self.assertion_fn(rtol=2e-2)(dist[slice_].mean(), rate[slice_])


if __name__ == '__main__':
  absltest.main()