33--------
44
55Probabilit lets the user perform Monte-Carlo sampling using a high-level
6- modeling language.
6+ modeling language, which creates a computational graph.
77
8- As a first look at the modeling language, let us do some computations.
9- We'll use constants before looking at random variables.
10-
11- Random samples can be drawn from a node using .sample(), which delegates to scipy:
8+ For instance, to compute the shipping cost of a box where we are uncertain
9+ about the measurements:
1210
1311>>> rng = np.random.default_rng(42)
14- >>> a = Constant(1)
15- >>> a.sample(5, random_state=rng)
16- array([1, 1, 1, 1, 1])
17-
18- Computational graphs can be built user overloaded Python operators.
19- Mixing numbers with nodes is allowed, but at least one expression or term
20- must be a probabilit class instance:
12+ >>> box_height = Distribution("norm", loc=0.5, scale=0.01)
13+ >>> box_width = Distribution("norm", loc=1, scale=0.01)
14+ >>> box_depth = Distribution("norm", loc=0.8, scale=0.01)
15+ >>> box_volume = box_height * box_width * box_depth
16+ >>> price_per_sqm = 50
17+ >>> price = box_volume * price_per_sqm
18+ >>> samples = price.sample(999, random_state=rng)
19+ >>> float(np.mean(samples))
20+ 20.00139737515...
21+
22+ Distributions are built on top of scipy, so "norm" refers to the name of the
23+ normal distribution as given in `scipy.stats`, and the arguments to the distribution
24+ must also match those given by `scipy.stats.norm`.
25+
26+ Here is another example showing composite distributions, where the argument
27+ to one distribution is another distribution:
28+
29+ >>> eggs_per_nest = Distribution("poisson", mu=3)
30+ >>> survival_prob = Distribution("beta", a=10, b=15)
31+ >>> survived = Distribution("binom", n=eggs_per_nest, p=survival_prob)
32+ >>> survived.sample(9, random_state=rng)
33+ array([1., 1., 1., 0., 2., 1., 1., 0., 2.])
34+
35+ To understand and examine the modeling language, we can do some computations
36+ with constants. The computational graph evaluates arithmetic operations lazily,
37+ only when the model is sampled. Mixing numbers with nodes is allowed, but at
38+ least one term in the expression must be a probabilit class instance:
2139
40+ >>> a = Constant(1)
2241>>> (a * 3 + 5).sample(5, random_state=rng)
2342array([8, 8, 8, 8, 8])
2443>>> Add(10, 5, 5).sample(5, random_state=rng)
2544array([20, 20, 20, 20, 20])
2645
27- Of course, things get more interesting with probability distributions.
28- The names and arguments correspond to scipy distributions (scipy.stats).
46+ Let us build a more complicated expression:
2947
3048>>> a = Distribution("norm", loc=5, scale=1)
3149>>> b = Distribution("expon", scale=1)
32- >>> product = a * b
33-
34- The product above is not evaluated untill we sample from it.
35-
36- >>> product.sample(5, random_state=rng)
37- array([ 3.32357208, 7.25992397, 13.68470082, 8.80523473, 2.31314151])
38-
39- Let us build a more compliated expression:
40-
4150>>> expression = a**b + a * b + 5 * b
4251
4352Every unique node in this expression can be found:
5665Sampling the expression is simple enough:
5766
5867>>> expression.sample(5, random_state=rng)
59- array([81.47571166, 36.25874807, 4.04413643, 1.78245506, 16.86301139])
68+ array([ 2.70764145, 36.58578812, 7.07064239, 1.84433247, 3.90951632])
6069
6170Sampling the expression has the side effect that `.samples_` is populated on
6271*every* node in the expression, for instance:
6372
6473>>> a.samples_
65- array([4.6702595 , 5.36880945, 4.85768145, 5.1372535 , 5.9448457 ])
74+ array([4.51589278, 4.37788659, 5.25960812, 5.80609507, 4.33770499])
6675
6776To sample using e.g. Latin Hypercube, do the following:
6877
6978>>> from scipy.stats.qmc import LatinHypercube
70- >>> d = expression.get_dimensionality()
79+ >>> d = expression.get_number_of_distribution_nodes()
7180>>> hypercube = LatinHypercube(d=d, rng=rng)
7281>>> hypercube_samples = hypercube.random(5) # Draw 5 samples
7382>>> expression.sample_from_cube(hypercube_samples)
7483array([ 1.20438726, 12.40283222, 5.02130766, 16.45109076, 77.12874028])
7584
76- Here is a more complex expression:
85+ Here is an even more complex expression:
7786
7887>>> a = Distribution("norm", loc=0, scale=1)
7988>>> b = Distribution("norm", loc=0, scale=2)
8089>>> c = Distribution("norm", loc=0, scale=3)
8190>>> expression = a*a - Add(a, b, c) + Abs(b)**Abs(c) + Exp(1 / Abs(c))
8291>>> expression.sample(5, random_state=rng)
83- array([-3.75434563, 5.84160178, 50.58877597, -1.32687877, 81.00831756])
84-
92+ array([ 4.70542018, 14.43250192, 6.74494838, -0.14020459, -3.27334554])
8593
8694Functions
95+ ---------
8796
8897If you have a function that is not an arithmetic expression, you can still
8998Monte-Carlo simulate through it with the `scalar_transform` decorator, which
90- will pass each sample through the computation node in a loop when we sample:
99+ will pass each sample through the computation node in a loop:
91100
92101>>> def function(a, b):
93102... if a > 0:
100109
101110>>> a = Distribution("norm", loc=0, scale=1)
102111>>> b = Distribution("norm", loc=0, scale=2)
103- >>> expression = function(a, b) # Function is not called here
112+ >>> expression = function(a, b) # Function is not actually called here
104113
105114Now sample 'through' the function:
106115
107116>>> expression.sample(5, random_state=rng)
108- array([ 0. , 0. , -0.13902087, 1.01335768, 0. ])
117+ array([0. , 0. , 0.45555522, 0. , 0. ])
109118"""
110119
111120import operator
112121import functools
113122import numpy as np
114123import numbers
115- import dataclasses
116124from scipy import stats
117125import abc
118126import itertools
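
The Latin Hypercube example in the docstring passes a cube of quantiles with one row per sample and one column per distribution to `sample_from_cube`. As a rough illustration of what such a cube looks like, here is a minimal standard-library sketch of stratified sampling. The helper `latin_hypercube` is a hypothetical stand-in written for this note; it is not how `scipy.stats.qmc.LatinHypercube` is implemented, only an instance of the same idea (one point per equal-width stratum in every column).

```python
import random


def latin_hypercube(num_samples, num_dims, rng):
    """Return num_samples rows of quantiles in [0, 1), stratified per column."""
    columns = []
    for _ in range(num_dims):
        # One uniform draw inside each of the num_samples equal-width strata
        column = [(i + rng.random()) / num_samples for i in range(num_samples)]
        rng.shuffle(column)  # Shuffle so strata are paired randomly across columns
        columns.append(column)
    # Transpose to shape (num_samples, num_dims)
    return [list(row) for row in zip(*columns)]


rng = random.Random(42)
cube = latin_hypercube(num_samples=5, num_dims=2, rng=rng)
```

Each column of `cube` covers every fifth of the unit interval exactly once, which is the property that makes the resulting samples more evenly spread than plain uniform draws.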
@@ -204,55 +212,68 @@ def python_to_prob(argument):
204212# is determined by the graph structure.
205213
206214
207- @dataclasses.dataclass
208215class Node(abc.ABC):
209216 """A node in the computational graph."""
210217
211- id_iter = itertools.count() # Everyone gets a unique ID
218+ id_iter = itertools.count() # Every node gets a unique ID
219+
220+ def __init__(self):
221+ self._id = next(self.id_iter)
212222
213223 def __eq__(self, other):
214224 return self._id == other._id
215225
216226 def __hash__(self):
217227 return self._id
218228
219- def __post_init__(self):
220- self._id = next(self.id_iter)
221-
222229 def nodes(self):
223- """Yields all ancestors using depth-first-search, including `self`."""
230+ """Yields `self` and all ancestors using depth-first-search.
231+
232+ Examples
233+ --------
234+ >>> expression = Distribution("norm") - 2**Constant(2)
235+ >>> for node in expression.nodes():
236+ ... print(node)
237+ Subtract(Distribution("norm"), Power(Constant(2), Constant(2)))
238+ Power(Constant(2), Constant(2))
239+ Constant(2)
240+ Constant(2)
241+ Distribution("norm")
242+ """
224243 queue = [(self)]
225244 while queue:
226245 yield (node := queue.pop())
227246 queue.extend(node.get_parents())
228247
229- def get_dimensionality(self):
248+ def get_number_of_distribution_nodes(self):
230249 return sum(1 for node in set(self.nodes()) if isinstance(node, Distribution))
231250
232251 def sample(self, size=None, random_state=None):
233- """Assign samples to self.samples_ rescursively."""
252+ """Sample the current node and assign to all node.samples_."""
234253 size = 1 if size is None else size
235254 random_state = np.random.default_rng(random_state)
236255
237256 # Draw a cube of random variables in [0, 1]
238- cube = random_state.random((size, self.get_dimensionality()))
257+ cube = random_state.random((size, self.get_number_of_distribution_nodes()))
239258
240259 return self.sample_from_cube(cube)
241260
242261 def sample_from_cube(self, cube):
243- """Use samples from a cube of shape (dimensionality, num_samples)."""
262+ """Use samples from a cube of quantiles in [0, 1] to sample all
263+ distributions. The cube must have shape (num_samples, dimensionality)."""
244264 assert nx.is_directed_acyclic_graph(self.to_graph())
245-
246265 size, n_dim = cube.shape
247- assert n_dim == self.get_dimensionality()
266+ assert n_dim == self.get_number_of_distribution_nodes()
267+
268+ # Prepare columns of quantiles, one column for each Distribution
248269 columns = iter(list(cube.T))
249270
250- # Clear any samples that might exist
271+ # Clear any samples that might exist in the graph
251272 for node in set(self.nodes()):
252273 if hasattr(node, "samples_"):
253274 delattr(node, "samples_")
254275
255- # Sample leaf nodes that are distributions first
276+ # Start with initial sampling nodes, which contain independent variables
256277 initial_sampling_nodes = [
257278 node for node in set(self.nodes()) if node._is_initial_sampling_node()
258279 ]
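
The traversal in `nodes()` and the reason the callers wrap it in `set(...)` can be sketched with a toy class. `ToyNode` below is a hypothetical stand-in, not probabilit's `Node`; it only copies the ID counter, the ID-based `__eq__`/`__hash__`, and the stack-based depth-first search shown in this hunk.

```python
import itertools


class ToyNode:
    """Hypothetical stand-in for a graph node; not probabilit's Node."""

    id_iter = itertools.count()  # Shared counter: every node gets a unique ID

    def __init__(self, name, *parents):
        self.name = name
        self.parents = parents
        self._id = next(self.id_iter)

    def __eq__(self, other):
        return self._id == other._id

    def __hash__(self):
        return self._id

    def nodes(self):
        """Yield `self` and all ancestors, depth-first, using a stack."""
        stack = [self]
        while stack:
            node = stack.pop()  # Pop from the end, so the last parent comes first
            yield node
            stack.extend(node.parents)


# Mirrors Distribution("norm") - 2**Constant(2) from the docstring
distribution = ToyNode("Distribution")
power = ToyNode("Power", ToyNode("Constant(2)"), ToyNode("Constant(2)"))
subtract = ToyNode("Subtract", distribution, power)
visit_order = [node.name for node in subtract.nodes()]

# A node used twice is yielded twice, so callers deduplicate with set()
shared = ToyNode("Distribution")
square = ToyNode("Multiply", shared, shared)
```

Because the two `Constant(2)` nodes have different IDs they stay distinct, while `shared` appears twice in the traversal of `square` and collapses to one entry in a set. This is why counting distributions over `set(self.nodes())` gives one quantile column per unique distribution.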
@@ -272,11 +293,8 @@ def sample_from_cube(self, cube):
272293
273294 # TODO: correlate the samples
274295
275- # Iterate over the remaining nodes and sample
276- remaining_nodes = nx.topological_sort(G)
277-
278296 # Iterate from leaf nodes and up to parent
279- for node in remaining_nodes:
297+ for node in nx.topological_sort(G):
280298 if hasattr(node, "samples_"):
281299 continue
282300 elif isinstance(node, Constant):
@@ -285,6 +303,8 @@ def sample_from_cube(self, cube):
285303 node.samples_ = node._sample(q=next(columns))
286304 elif isinstance(node, Transform):
287305 node.samples_ = node._sample()
306+ else:
307+ raise TypeError("Node must be Constant, Distribution or Transform.")
288308
289309 return self.samples_
290310
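
The evaluation loop above visits nodes in topological order, so every parent has samples before its children are computed. A minimal sketch of that pattern follows, assuming Python 3.9+ and using the standard library's `graphlib.TopologicalSorter` in place of `nx.topological_sort`. The dictionary-based graph, the node names, and the `(kind, payload, parents)` layout are illustrative assumptions, not probabilit's data model.

```python
import operator
from graphlib import TopologicalSorter
from statistics import NormalDist

# Hypothetical graph for (a * 3 + 5), where a ~ Normal(5, 1).
# Each entry maps a node name to (kind, payload, parent names).
graph = {
    "a": ("distribution", NormalDist(mu=5, sigma=1), ()),
    "three": ("constant", 3, ()),
    "five": ("constant", 5, ()),
    "product": ("transform", operator.mul, ("a", "three")),
    "total": ("transform", operator.add, ("product", "five")),
}

size = 3
quantile_columns = iter([[0.1, 0.5, 0.9]])  # One column per distribution node

# TopologicalSorter takes {node: predecessors}, so parents are emitted first
order = TopologicalSorter({name: spec[2] for name, spec in graph.items()}).static_order()

samples = {}
for name in order:
    kind, payload, parents = graph[name]
    if kind == "constant":
        samples[name] = [payload] * size
    elif kind == "distribution":
        # Consume the next column of quantiles and map it through the inverse CDF
        samples[name] = [payload.inv_cdf(q) for q in next(quantile_columns)]
    elif kind == "transform":
        left, right = (samples[parent] for parent in parents)
        samples[name] = [payload(x, y) for x, y in zip(left, right)]
    else:
        raise TypeError("Node must be constant, distribution or transform.")
```

After the loop, `samples` holds values for every node, which corresponds to the side effect of populating `.samples_` across the whole graph.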
@@ -333,7 +353,7 @@ def to_graph(self):
333353
334354
335355class OverloadMixin:
336- """Overloads dunder (double underscore) methods."""
356+ """Overloads dunder (double underscore) methods for easier modeling."""
337357
338358 def __add__(self, other):
339359 return Add(self, other)
@@ -375,7 +395,7 @@ def __abs__(self):
375395class Constant(Node, OverloadMixin):
376396 """A constant is a number."""
377397
378- is_leaf = True
398+ is_leaf = True # A Constant is always a leaf node
379399
380400 def __init__(self, value):
381401 self.value = value
@@ -387,7 +407,7 @@ def _sample(self, size=None, random_state=None):
387407 return np.ones(size, dtype=type(self.value)) * self.value
388408
389409 def get_parents(self):
390- return []
410+ return [] # A Constant does not have any parents
391411
392412 def __repr__(self):
393413 return f"{type(self).__name__}({self.value})"
@@ -414,16 +434,19 @@ def __repr__(self):
414434
415435 def _sample(self, q):
416436 def unpack(arg):
437+ """Unpack distribution arguments (parents) to arrays if Node."""
417438 return arg.samples_ if isinstance(arg, Node) else arg
418439
440+ # Parse the arguments and keyword arguments for the distribution
419441 args = tuple(unpack(arg) for arg in self.args)
420442 kwargs = {k: unpack(v) for (k, v) in self.kwargs.items()}
421443
444+ # Sample from the distribution with inverse CDF
422445 distribution = getattr(stats, self.distr)
423446 return distribution(*args, **kwargs).ppf(q)
424447
425448 def get_parents(self):
426- # A distribution only has parents if its parameters are Nodes
449+ # A distribution only has parents if it has parameters that are Nodes
427450 for arg in self.args + tuple(self.kwargs.values()):
428451 if isinstance(arg, Node):
429452 yield arg
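
`Distribution._sample` draws values by pushing a column of uniform quantiles through the distribution's percent-point function (inverse CDF). The sketch below shows the same idea with only the standard library, using `statistics.NormalDist.inv_cdf` as a stand-in for `scipy.stats.norm(...).ppf`. The helper name `sample_normal_from_quantiles`, the variable `price_per_unit_volume`, and the clamping constant are assumptions made for this illustration; the box parameters are taken from the docstring example.

```python
import random
from statistics import NormalDist, fmean


def sample_normal_from_quantiles(quantiles, loc, scale):
    """Map uniform quantiles to normal samples via the inverse CDF."""
    dist = NormalDist(mu=loc, sigma=scale)
    eps = 1e-12  # inv_cdf is undefined at exactly 0 or 1, so clamp the input
    return [dist.inv_cdf(min(max(q, eps), 1 - eps)) for q in quantiles]


rng = random.Random(42)
num_samples = 999

# A cube of shape (num_samples, 3): one column of quantiles per distribution
cube = [[rng.random() for _ in range(3)] for _ in range(num_samples)]
columns = list(zip(*cube))  # Transpose to get one sequence per distribution

height = sample_normal_from_quantiles(columns[0], loc=0.5, scale=0.01)
width = sample_normal_from_quantiles(columns[1], loc=1.0, scale=0.01)
depth = sample_normal_from_quantiles(columns[2], loc=0.8, scale=0.01)

price_per_unit_volume = 50
prices = [h * w * d * price_per_unit_volume for h, w, d in zip(height, width, depth)]
mean_price = fmean(prices)  # Close to 0.5 * 1.0 * 0.8 * 50 = 20
```

Separating the quantiles from the distributions is what lets the same graph be driven by plain uniform draws or by a quasi-random cube without changing the distribution code.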
@@ -447,6 +470,12 @@ def __repr__(self):
447470
448471
449472class VariadicTransform(Transform):
473+ """Parent class for variadic transforms (must be associative), e.g.
474+ Add(arg1, arg2, arg3, arg4, ...)
475+ Multiply(arg1, arg2, arg3, arg4, ...)
476+
477+ """
478+
450479 def __init__(self, *args):
451480 self.parents = tuple(python_to_prob(arg) for arg in args)
452481 super().__init__()
@@ -468,6 +497,8 @@ class Multiply(VariadicTransform):
468497
469498
470499class BinaryTransform(Transform):
500+ """Class for binary transforms, such as Divide, Power, Subtract, etc."""
501+
471502 def __init__(self, *args):
472503 self.parents = tuple(python_to_prob(arg) for arg in args)
473504 super().__init__()
@@ -493,6 +524,9 @@ class Subtract(BinaryTransform):
493524
494525
495526class UnaryTransform(Transform):
527+ """Class for unary transforms, i.e. functions that take one argument, such
528+ as Abs(), Exp(), Log()."""
529+
496530 def __init__(self, arg):
497531 self.parent = python_to_prob(arg)
498532 super().__init__()