@@ -293,7 +293,7 @@ def update(item):
293293
294294 # Now that the node has been updated, update references to parents
295295 # to point to Nodes in the new copied graph instead of the old one.
296- if isinstance (copied , (Distribution , ScalarFunctionTransform )):
296+ if isinstance (copied , (AbstractDistribution , ScalarFunctionTransform )):
297297 copied .args = tuple (update (arg ) for arg in copied .args )
298298 copied .kwargs = {k : update (v ) for (k , v ) in copied .kwargs .items ()}
299299 elif isinstance (copied , (VariadicTransform , BinaryTransform )):
@@ -328,9 +328,7 @@ def nodes(self):
328328
329329 def num_distribution_nodes (self ):
330330 return sum (
331- 1
332- for node in set (self .nodes ())
333- if isinstance (node , (Distribution , EmpiricalDistribution ))
331+ 1 for node in set (self .nodes ()) if isinstance (node , AbstractDistribution )
334332 )
335333
336334 def sample (self , size = None , random_state = None , method = None ):
@@ -384,11 +382,11 @@ def sample_from_quantiles(self, quantiles):
384382 # Sample all ancestors
385383 ancestors = G .subgraph (nx .ancestors (G , node ))
386384 for ancestor in nx .topological_sort (ancestors ):
387- assert isinstance (ancestor , (Constant , Distribution ))
385+ assert isinstance (ancestor , (Constant , AbstractDistribution ))
388386 ancestor .samples_ = ancestor ._sample (size = size )
389387
390388 # Sample the node
391- assert isinstance (node , ( Distribution , EmpiricalDistribution ) )
389+ assert isinstance (node , AbstractDistribution )
392390 node .samples_ = node ._sample (q = next (columns ))
393391
394392 # Go through all ancestor nodes and create a list [(var, corr), ...]
@@ -441,7 +439,7 @@ def sample_from_quantiles(self, quantiles):
441439 continue
442440 elif isinstance (node , Constant ):
443441 node .samples_ = node ._sample (size = size )
444- elif isinstance (node , ( Distribution , EmpiricalDistribution ) ):
442+ elif isinstance (node , AbstractDistribution ):
445443 node .samples_ = node ._sample (q = next (columns ))
446444 elif isinstance (node , Transform ):
447445 node .samples_ = node ._sample ()
@@ -454,12 +452,12 @@ def _is_initial_sampling_node(self):
454452 """A node is an initial sample node iff:
455453 (1) It is a Distribution
456454 (2) None of its ancestors are Distributions (all are Constant/Transform)"""
457- if isinstance (self , EmpiricalDistribution ):
458- return True
459455
460- is_distribution = isinstance (self , Distribution )
456+ is_distribution = isinstance (self , AbstractDistribution )
461457 ancestors = set (self .nodes ()) - set ([self ])
462- ancestors_distr = any (isinstance (node , Distribution ) for node in ancestors )
458+ ancestors_distr = any (
459+ isinstance (node , AbstractDistribution ) for node in ancestors
460+ )
463461 return is_distribution and not ancestors_distr
464462
465463 def correlate (self , * variables , corr_mat ):
@@ -604,7 +602,11 @@ def __repr__(self):
604602 return f"{ type (self ).__name__ } ({ self .value } )"
605603
606604
607- class Distribution (Node , OverloadMixin ):
605+ class AbstractDistribution (Node , OverloadMixin , abc .ABC ):
606+ pass
607+
608+
609+ class Distribution (AbstractDistribution ):
608610 """A distribution is a sampling node with or without ancestors."""
609611
610612 def __init__ (self , distr , * args , ** kwargs ):
@@ -647,7 +649,7 @@ def is_leaf(self):
647649 return list (self .get_parents ()) == []
648650
649651
650- class EmpiricalDistribution (Node , OverloadMixin ):
652+ class EmpiricalDistribution (AbstractDistribution ):
651653 """A distribution is a sampling node with or without ancestors.
652654
653655 A thin wrapper around numpy.quantile."""
@@ -672,7 +674,7 @@ def get_parents(self):
672674# ========================================================
673675
674676
675- class Transform (Node , abc .ABC , OverloadMixin ):
677+ class Transform (Node , OverloadMixin , abc .ABC ):
676678 """Transform nodes represent arithmetic operations."""
677679
678680 is_leaf = False
@@ -898,4 +900,5 @@ def transformed_function(*args, **kwargs):
898900 # =========================
899901
900902 cost = EmpiricalDistribution (data = [1 , 2 , 3 , 3 , 3 , 3 ])
901- (cost ** 2 ).sample (99 , random_state = 42 )
903+ norm = Distribution ("norm" , loc = cost , scale = 1 )
904+ (norm ** 2 ).sample (99 , random_state = 42 )
0 commit comments