diff --git a/src/probabilit/inspection.py b/src/probabilit/inspection.py index d3f8053..97e6efe 100644 --- a/src/probabilit/inspection.py +++ b/src/probabilit/inspection.py @@ -38,3 +38,11 @@ def plot(*variables, corr=None, **kwargs): plot(a, b) grid = plot(a) + + from probabilit.modeling import MultivariateDistribution + + cov = np.array([[1, 0.9], [0.9, 1]]) + n1, n2 = MultivariateDistribution("multivariate_normal", mean=[1, 2], cov=cov) + from probabilit.inspection import plot + + plot(n1, n2) diff --git a/src/probabilit/modeling.py b/src/probabilit/modeling.py index 576f94d..653b4e7 100644 --- a/src/probabilit/modeling.py +++ b/src/probabilit/modeling.py @@ -114,8 +114,40 @@ 200. , 200. , 200. , 273.23522915, 235.11150117]) +Multivariate distributions +-------------------------- +Support for multivariate distributions (MVD) is implemented, but with constraints: + + 1. MVD must be a leaf node (its arguments cannot be other distributions) + 2. its return values *must* be unpacked as marginals (slices) + 3. only pseudo-random sampling is possible (LHS, Sobol, etc. are ignored) + +For instance, to create a Dirichlet distribution, we must unpack it as follows: + +>>> d1, d2 = MultivariateDistribution("dirichlet", alpha=[1, 2]) +>>> d1 +MarginalDistribution(Distribution("dirichlet", alpha=[1, 2]), d=0) + +Since the Dirichlet distribution is defined on an (n-1)-dimensional simplex, +the sum of the marginals is always 1. 
We can check this by computing: + +>>> (d1 + d2).sample(5, random_state=0) +array([1., 1., 1., 1., 1.]) +>>> d2.samples_.round(3) +array([0.559, 0.324, 0.284, 0.524, 0.599]) + +Here is an example with a multivariate normal distribution: + +>>> cov = np.array([[1, 0.5], [0.5, 1]]) +>>> n1, n2 = MultivariateDistribution("multivariate_normal", mean=[1, 2], cov=cov) +>>> n1.sample(5, random_state=0) +array([0.72058767, 3.13703525, 2.38930155, 1.50866787, 0.77018653]) +>>> (n1 + n2).sample(5, random_state=0) +array([2.52848604, 5.31650094, 5.20076878, 4.06217341, 1.40748585]) + Samplers -------- + The default sampling uses pseudo-random numbers. To use e.g. latin hybercube sampling, pass the `method` argument into `.sample()`. @@ -647,7 +679,13 @@ def unpack(arg): # Sample from the distribution with inverse CDF distribution = getattr(stats, self.distr) - return distribution(*args, **kwargs).ppf(q) + try: + return distribution(*args, **kwargs).ppf(q) + except AttributeError: + # Multivariate distributions do not have .ppf() + # isinstance(distribution, (multi_rv_generic, multi_rv_frozen)) + seed = int(q[0] * 2**20) # Seed based on q + return distribution(*args, **kwargs).rvs(size=len(q), random_state=seed) def get_parents(self): # A distribution only has parents if it has parameters that are Nodes @@ -972,6 +1010,58 @@ def transformed_function(*args, **kwargs): return transformed_function +class MarginalDistribution(Transform): + """A marginal distribution is a 'slice' of a multivariate distribution. 
+ + Examples + -------- + >>> distr = Distribution("multinomial", n=10, p=[0.1, 0.2, 0.7]) + >>> marginal_distr = MarginalDistribution(distr, d=0) + >>> marginal_distr + MarginalDistribution(Distribution("multinomial", n=10, p=[0.1, 0.2, 0.7]), d=0) + >>> marginal_distr.sample(5, random_state=0).astype(int) + array([2, 1, 2, 1, 1]) + """ + + is_leaf = False + + def __init__(self, distr, d): + self.distr = distr + self.d = d + super().__init__() + + def _sample(self): + # Simply slice the parent: column d of its (at least 2D) samples + return np.atleast_2d(self.distr.samples_)[:, self.d] + + def get_parents(self): + yield self.distr + + def __repr__(self): + return f"{type(self).__name__}({self.distr}, d={self.d})" + + +def MultivariateDistribution(distr, *args, **kwargs): + """Factory function that yields one marginal per dimension of the distribution. + + Examples + -------- + >>> p = [0.2, 0.3, 0.5] # Probability of each category + >>> m1, m2, m3 = MultivariateDistribution("multinomial", n=10, p=p) + >>> m1.sample(5, random_state=0).astype(int) + array([3, 2, 4, 2, 1]) + + The samples across all categories should sum to n=10: + >>> (m1 + m2 + m3).sample(5, random_state=0).astype(int) + array([10, 10, 10, 10, 10]) + """ + distr = Distribution(distr, *args, **kwargs) + + # Get dimensionality by sampling once (NOTE(review): verify d == 1 works) + d = len(distr._sample(q=[0.5]).squeeze()) + yield from (MarginalDistribution(distr, d=i) for i in range(d)) + + # ======================================================== if __name__ == "__main__": import pytest @@ -999,6 +1089,8 @@ def transformed_function(*args, **kwargs): plt.scatter(a.samples_, b.samples_, s=2) plt.show() + d1, d2, d3 = MultivariateDistribution("dirichlet", alpha=[1, 2, 3]) + # ========================= cost = EmpiricalDistribution(data=[1, 2, 3, 3, 3, 3])