Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/probabilit/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ def plot(*variables, corr=None, **kwargs):

plot(a, b)
grid = plot(a)

from probabilit.modeling import MultivariateDistribution

cov = np.array([[1, 0.9], [0.9, 1]])
n1, n2 = MultivariateDistribution("multivariate_normal", mean=[1, 2], cov=cov)
from probabilit.inspection import plot

plot(n1, n2)
94 changes: 93 additions & 1 deletion src/probabilit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,40 @@
200. , 200. , 200. , 273.23522915,
235.11150117])

Multivariate distributions
--------------------------
Support for multivariate distributions (MVD) is implemented, but with constraints:

1. MVD must be a leaf node (its arguments cannot be other distributions)
2. its return values *must* be unpacked as marginals (slices)
3. only pseudo-random sampling is possible (LHS, Sobol, etc. are ignored)

For instance, to create a Dirichlet distribution, we must unpack it as follows:

>>> d1, d2 = MultivariateDistribution("dirichlet", alpha=[1, 2])
>>> d1
MarginalDistribution(Distribution("dirichlet", alpha=[1, 2]), d=0)

Since the Dirichlet distribution is defined on an (n-1) dimensional simplex,
the sum of the marginals is always 1. We can check this by computing:

>>> (d1 + d2).sample(5, random_state=0)
array([1., 1., 1., 1., 1.])
>>> d2.samples_.round(3)
array([0.559, 0.324, 0.284, 0.524, 0.599])

Here is an example with a multivariate normal distribution:

>>> cov = np.array([[1, 0.5], [0.5, 1]])
>>> n1, n2 = MultivariateDistribution("multivariate_normal", mean=[1, 2], cov=cov)
>>> n1.sample(5, random_state=0)
array([0.72058767, 3.13703525, 2.38930155, 1.50866787, 0.77018653])
>>> (n1 + n2).sample(5, random_state=0)
array([2.52848604, 5.31650094, 5.20076878, 4.06217341, 1.40748585])

Samplers
--------

The default sampling uses pseudo-random numbers. To use e.g. latin hypercube
sampling, pass the `method` argument into `.sample()`.

Expand Down Expand Up @@ -647,7 +679,13 @@ def unpack(arg):

# Sample from the distribution with inverse CDF
distribution = getattr(stats, self.distr)
return distribution(*args, **kwargs).ppf(q)
try:
return distribution(*args, **kwargs).ppf(q)
except AttributeError:
# Multivariate distributions do not have .ppf()
# isinstance(distribution, (multi_rv_generic, multi_rv_frozen))
seed = int(q[0] * 2**20) # Seed based on q
return distribution(*args, **kwargs).rvs(size=len(q), random_state=seed)

def get_parents(self):
# A distribution only has parents if it has parameters that are Nodes
Expand Down Expand Up @@ -972,6 +1010,58 @@ def transformed_function(*args, **kwargs):
return transformed_function


class MarginalDistribution(Transform):
    """A marginal distribution is a 'slice' of a multivariate distribution.

    The wrapped multivariate distribution is the only parent of this node.
    Sampling this node returns column `d` of the parent's samples.

    Parameters
    ----------
    distr : the multivariate distribution node to slice.
    d : index of the column (dimension) to extract from the parent's samples.

    Examples
    --------
    >>> distr = Distribution("multinomial", n=10, p=[0.1, 0.2, 0.7])
    >>> marginal_distr = MarginalDistribution(distr, d=0)
    >>> marginal_distr
    MarginalDistribution(Distribution("multinomial", n=10, p=[0.1, 0.2, 0.7]), d=0)
    >>> marginal_distr.sample(5, random_state=0).astype(int)
    array([2, 1, 2, 1, 1])
    """

    # This node has a parent (the multivariate distribution it slices)
    is_leaf = False

    def __init__(self, distr, d):
        # Attributes are assigned before the base-class initializer runs
        self.distr = distr
        self.d = d
        super().__init__()

    def _sample(self):
        # View the parent's samples as a 2D array, then pick out column d
        parent_samples = np.atleast_2d(self.distr.samples_)
        return parent_samples[:, self.d]

    def get_parents(self):
        # The multivariate distribution is the one and only parent
        yield from (self.distr,)

    def __repr__(self):
        return f"{type(self).__name__}({self.distr}, d={self.d})"


def MultivariateDistribution(distr, *args, **kwargs):
    """Generator that yields the marginals of a multivariate distribution.

    A single `Distribution` is built from the arguments, and one
    `MarginalDistribution` is yielded per dimension, all sharing that same
    underlying distribution. Since this is a generator function, the result
    is meant to be unpacked (or iterated over).

    Parameters
    ----------
    distr : name of the distribution, forwarded to `Distribution`.
    *args, **kwargs : distribution parameters, forwarded to `Distribution`.

    Examples
    --------
    >>> p = [0.2, 0.3, 0.5] # Probability of each category
    >>> m1, m2, m3 = MultivariateDistribution("multinomial", n=10, p=p)
    >>> m1.sample(5, random_state=0).astype(int)
    array([3, 2, 4, 2, 1])

    Each category should sum to n=10:
    >>> (m1 + m2 + m3).sample(5, random_state=0).astype(int)
    array([10, 10, 10, 10, 10])
    """
    joint = Distribution(distr, *args, **kwargs)

    # Draw a single sample to discover the number of dimensions
    probe = joint._sample(q=[0.5])
    n_dims = len(probe.squeeze())

    for i in range(n_dims):
        yield MarginalDistribution(joint, d=i)


# ========================================================
if __name__ == "__main__":
import pytest
Expand Down Expand Up @@ -999,6 +1089,8 @@ def transformed_function(*args, **kwargs):
plt.scatter(a.samples_, b.samples_, s=2)
plt.show()

d1, d2, d3 = MultivariateDistribution("dirichlet", alpha=[1, 2, 3])

# =========================

cost = EmpiricalDistribution(data=[1, 2, 3, 3, 3, 3])
Expand Down