Skip to content
92 changes: 91 additions & 1 deletion src/probabilit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,38 @@
200. , 200. , 200. , 273.23522915,
235.11150117])

Multivariate distributions
--------------------------
Support for multivariate distributions (MVD) is implemented, but with constraints:

1. MVD must be a leaf node (its arguments cannot be other distributions)
2. its return values *must* be unpacked as marginals (slices)
3. only pseudo-random sampling is possible (LHS, Sobol, etc is ignored)

For instance, to create a Dirichlet distribution, we must unpack it as follows:

>>> d1, d2 = MultivariateDistribution("dirichlet", alpha=[1, 2])
>>> d1
MarginalDistribution(Distribution("dirichlet", alpha=[1, 2]), d=0)

Since the Direchlet distribution is defined on an (n-1) dimensional simplex,
the sum of the marginals is always 1. We can check this by computing:

>>> (d1 + d2).sample(5, random_state=0)
array([1., 1., 1., 1., 1.])
>>> d2.samples_.round(3)
array([0.559, 0.324, 0.284, 0.524, 0.599])

Here is an example with a multivariate normal distribution:

>>> cov = np.array([[1, 0.5], [0.5, 1]])
>>> n1, n2 = MultivariateDistribution("multivariate_normal", mean=[1, 2], cov=cov)
>>> n1.sample(5, random_state=0)
array([0.72058767, 3.13703525, 2.38930155, 1.50866787, 0.77018653])

Samplers
--------

The default sampling uses pseudo-random numbers. To use e.g. latin hybercube
sampling, pass the `method` argument into `.sample()`.

Expand Down Expand Up @@ -647,7 +677,13 @@ def unpack(arg):

# Sample from the distribution with inverse CDF
distribution = getattr(stats, self.distr)
return distribution(*args, **kwargs).ppf(q)
try:
return distribution(*args, **kwargs).ppf(q)
except AttributeError:
# Multivariate distributions do not have .ppf()
# isinstance(distribution, (multi_rv_generic, multi_rv_frozen))
seed = int(q[0] * 2**20) # Seed based on q
return distribution(*args, **kwargs).rvs(size=len(q), random_state=seed)

def get_parents(self):
# A distribution only has parents if it has parameters that are Nodes
Expand Down Expand Up @@ -972,6 +1008,58 @@ def transformed_function(*args, **kwargs):
return transformed_function


class MarginalDistribution(Transform):
"""A maginal distribution is a 'slice' of a multivariate distribution.

Examples
--------
>>> distr = Distribution("multinomial", n=10, p=[0.1, 0.2, 0.7])
>>> marginal_distr = MarginalDistribution(distr, d=0)
>>> marginal_distr
MarginalDistribution(Distribution("multinomial", n=10, p=[0.1, 0.2, 0.7]), d=0)
>>> marginal_distr.sample(5, random_state=0).astype(int)
array([2, 1, 2, 1, 1])
"""

is_leaf = False

def __init__(self, distr, d):
self.distr = distr
self.d = d
super().__init__()

def _sample(self):
# Simply slice the parent
return self.distr.samples_[:, self.d]

def get_parents(self):
yield self.distr

def __repr__(self):
return f"{type(self).__name__}({self.distr}, d={self.d})"


def MultivariateDistribution(distr, *args, **kwargs):
"""Factory function that yields marginal distributions.

Examples
--------
>>> p = [0.2, 0.3, 0.5] # Probability of each category
>>> m1, m2, m3 = MultivariateDistribution("multinomial", n=10, p=p)
>>> m1.sample(5, random_state=0).astype(int)
array([3, 2, 4, 2, 1])

Each category should sum to n=10:
>>> (m1 + m2 + m3).sample(5, random_state=0).astype(int)
array([10, 10, 10, 10, 10])
"""
distr = Distribution(distr, *args, **kwargs)

# Get dimensionality by sampling once
d = len(distr._sample(q=[0.5]).squeeze())
yield from (MarginalDistribution(distr, d=i) for i in range(d))


# ========================================================
if __name__ == "__main__":
import pytest
Expand Down Expand Up @@ -999,6 +1087,8 @@ def transformed_function(*args, **kwargs):
plt.scatter(a.samples_, b.samples_, s=2)
plt.show()

d1, d2, d3 = MultivariateDistribution("dirichlet", alpha=[1, 2, 3])

# =========================

cost = EmpiricalDistribution(data=[1, 2, 3, 3, 3, 3])
Expand Down