Skip to content

Commit 3ea01fe

Browse files
committed
Added AbstractDistribution class
1 parent 3bb231d commit 3ea01fe

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

src/probabilit/modeling.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def update(item):
293293

294294
# Now that the node has been updated, update references to parents
295295
# to point to Nodes in the new copied graph instead of the old one.
296-
if isinstance(copied, (Distribution, ScalarFunctionTransform)):
296+
if isinstance(copied, (AbstractDistribution, ScalarFunctionTransform)):
297297
copied.args = tuple(update(arg) for arg in copied.args)
298298
copied.kwargs = {k: update(v) for (k, v) in copied.kwargs.items()}
299299
elif isinstance(copied, (VariadicTransform, BinaryTransform)):
@@ -328,9 +328,7 @@ def nodes(self):
328328

329329
def num_distribution_nodes(self):
330330
return sum(
331-
1
332-
for node in set(self.nodes())
333-
if isinstance(node, (Distribution, EmpiricalDistribution))
331+
1 for node in set(self.nodes()) if isinstance(node, AbstractDistribution)
334332
)
335333

336334
def sample(self, size=None, random_state=None, method=None):
@@ -384,11 +382,11 @@ def sample_from_quantiles(self, quantiles):
384382
# Sample all ancestors
385383
ancestors = G.subgraph(nx.ancestors(G, node))
386384
for ancestor in nx.topological_sort(ancestors):
387-
assert isinstance(ancestor, (Constant, Distribution))
385+
assert isinstance(ancestor, (Constant, AbstractDistribution))
388386
ancestor.samples_ = ancestor._sample(size=size)
389387

390388
# Sample the node
391-
assert isinstance(node, (Distribution, EmpiricalDistribution))
389+
assert isinstance(node, AbstractDistribution)
392390
node.samples_ = node._sample(q=next(columns))
393391

394392
# Go through all ancestor nodes and create a list [(var, corr), ...]
@@ -441,7 +439,7 @@ def sample_from_quantiles(self, quantiles):
441439
continue
442440
elif isinstance(node, Constant):
443441
node.samples_ = node._sample(size=size)
444-
elif isinstance(node, (Distribution, EmpiricalDistribution)):
442+
elif isinstance(node, AbstractDistribution):
445443
node.samples_ = node._sample(q=next(columns))
446444
elif isinstance(node, Transform):
447445
node.samples_ = node._sample()
@@ -454,12 +452,12 @@ def _is_initial_sampling_node(self):
454452
"""A node is an initial sample node iff:
455453
(1) It is a Distribution
456454
(2) None of its ancestors are Distributions (all are Constant/Transform)"""
457-
if isinstance(self, EmpiricalDistribution):
458-
return True
459455

460-
is_distribution = isinstance(self, Distribution)
456+
is_distribution = isinstance(self, AbstractDistribution)
461457
ancestors = set(self.nodes()) - set([self])
462-
ancestors_distr = any(isinstance(node, Distribution) for node in ancestors)
458+
ancestors_distr = any(
459+
isinstance(node, AbstractDistribution) for node in ancestors
460+
)
463461
return is_distribution and not ancestors_distr
464462

465463
def correlate(self, *variables, corr_mat):
@@ -604,7 +602,11 @@ def __repr__(self):
604602
return f"{type(self).__name__}({self.value})"
605603

606604

607-
class Distribution(Node, OverloadMixin):
605+
class AbstractDistribution(Node, OverloadMixin, abc.ABC):
606+
pass
607+
608+
609+
class Distribution(AbstractDistribution):
608610
"""A distribution is a sampling node with or without ancestors."""
609611

610612
def __init__(self, distr, *args, **kwargs):
@@ -647,7 +649,7 @@ def is_leaf(self):
647649
return list(self.get_parents()) == []
648650

649651

650-
class EmpiricalDistribution(Node, OverloadMixin):
652+
class EmpiricalDistribution(AbstractDistribution):
651653
"""A distribution is a sampling node with or without ancestors.
652654
653655
A thin wrapper around numpy.quantile."""
@@ -672,7 +674,7 @@ def get_parents(self):
672674
# ========================================================
673675

674676

675-
class Transform(Node, abc.ABC, OverloadMixin):
677+
class Transform(Node, OverloadMixin, abc.ABC):
676678
"""Transform nodes represent arithmetic operations."""
677679

678680
is_leaf = False
@@ -898,4 +900,5 @@ def transformed_function(*args, **kwargs):
898900
# =========================
899901

900902
cost = EmpiricalDistribution(data=[1, 2, 3, 3, 3, 3])
901-
(cost**2).sample(99, random_state=42)
903+
norm = Distribution("norm", loc=cost, scale=1)
904+
(norm**2).sample(99, random_state=42)

tests/test_modeling.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from probabilit.modeling import (
2+
EmpiricalDistribution,
23
Constant,
34
Log,
45
Exp,
@@ -179,6 +180,13 @@ def test_constant_idempotent():
179180
assert Constant(Constant(a)).value == Constant(a).value
180181

181182

183+
def test_empirical_distribution():
184+
# Test that an empirical distribution can be a parameter
185+
location = EmpiricalDistribution(data=[1, 2, 3, 3, 3, 3])
186+
result = Distribution("norm", loc=location, scale=1)
187+
(result**2).sample(99, random_state=42)
188+
189+
182190
if __name__ == "__main__":
183191
import pytest
184192

0 commit comments

Comments
 (0)