Commit 386f77d

rewrote docs a bit
1 parent c9e0a2c commit 386f77d

2 files changed: +95 -55 lines changed

src/probabilit/modeling.py

Lines changed: 89 additions & 55 deletions
@@ -3,41 +3,50 @@
 --------

 Probabilit lets the user perform Monte-Carlo sampling using a high-level
-modeling language.
+modeling language, which creates a computational graph.

-As a first look at the modeling language, let us do some computations.
-We'll use constants before looking at random variables.
-
-Random samples can be drawn from a node using .sample(), which delegates to scipy:
+For instance, to compute the shipping cost of a box where we are uncertain
+about the measurements:

 >>> rng = np.random.default_rng(42)
->>> a = Constant(1)
->>> a.sample(5, random_state=rng)
-array([1, 1, 1, 1, 1])
-
-Computational graphs can be built user overloaded Python operators.
-Mixing numbers with nodes is allowed, but at least one expression or term
-must be a probabilit class instance:
+>>> box_height = Distribution("norm", loc=0.5, scale=0.01)
+>>> box_width = Distribution("norm", loc=1, scale=0.01)
+>>> box_depth = Distribution("norm", loc=0.8, scale=0.01)
+>>> box_volume = box_height * box_width * box_depth
+>>> price_per_sqm = 50
+>>> price = box_volume * price_per_sqm
+>>> samples = price.sample(999, random_state=rng)
+>>> float(np.mean(samples))
+20.00139737515...
+
+Distributions are built on top of scipy, so "norm" refers to the name of the
+normal distribution as given in `scipy.stats`, and the arguments to the distribution
+must also match those given by `scipy.stats.norm`.
+
+Here is another example showing composite distributions, where the argument
+to one distribution is another distribution:
+
+>>> eggs_per_nest = Distribution("poisson", mu=3)
+>>> survivial_prob = Distribution("beta", a=10, b=15)
+>>> survived = Distribution("binom", n=eggs_per_nest, p=survivial_prob)
+>>> survived.sample(9, random_state=rng)
+array([1., 1., 1., 0., 2., 1., 1., 0., 2.])
+
+To understand and examine the modeling language, we can do some computations
+with constants. The computational graph carries out arithmetic operations lazily
+once a model is sampled. Mixing numbers with nodes is allowed, but at least one
+expression or term must be a probabilit class instance:

+>>> a = Constant(1)
 >>> (a * 3 + 5).sample(5, random_state=rng)
 array([8, 8, 8, 8, 8])
 >>> Add(10, 5, 5).sample(5, random_state=rng)
 array([20, 20, 20, 20, 20])

-Of course, things get more interesting with probability distributions.
-The names and arguments correspond to scipy distributions (scipy.stats).
+Let us build a more complicated expression:

 >>> a = Distribution("norm", loc=5, scale=1)
 >>> b = Distribution("expon", scale=1)
->>> product = a * b
-
-The product above is not evaluated untill we sample from it.
-
->>> product.sample(5, random_state=rng)
-array([ 3.32357208, 7.25992397, 13.68470082, 8.80523473, 2.31314151])
-
-Let us build a more compliated expression:
-
 >>> expression = a**b + a * b + 5 * b

 Every unique node in this expression can be found:
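A quick sanity check on the shipping example in this hunk: the three measurements are independent, so the expected price is the product of the means times the unit price, 0.5 * 1 * 0.8 * 50 = 20, which agrees with the doctest mean of roughly 20.001. The sketch below reproduces that with the standard library only. It does not use probabilit, and `simulate_price` is a hypothetical helper introduced for illustration.

```python
import random
import statistics


def simulate_price(n_samples, seed=42, price_per_unit=50):
    """Monte-Carlo estimate of the mean box price (stdlib-only sketch)."""
    rng = random.Random(seed)
    prices = []
    for _ in range(n_samples):
        # Independent normal measurements, as in the docstring example
        height = rng.gauss(0.5, 0.01)
        width = rng.gauss(1.0, 0.01)
        depth = rng.gauss(0.8, 0.01)
        prices.append(height * width * depth * price_per_unit)
    return statistics.fmean(prices)


# For independent X, Y, Z: E[XYZ] = E[X] * E[Y] * E[Z]
analytic_mean = 0.5 * 1.0 * 0.8 * 50
estimated_mean = simulate_price(10_000)
```

The estimate should land close to the analytic value; the gap shrinks as the number of samples grows.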
@@ -56,38 +65,38 @@
 Sampling the expression is simple enough:

 >>> expression.sample(5, random_state=rng)
-array([81.47571166, 36.25874807, 4.04413643, 1.78245506, 16.86301139])
+array([ 2.70764145, 36.58578812, 7.07064239, 1.84433247, 3.90951632])

 Sampling the expression has the side effect that `.samples_` is populated on
 *every* node in the expression, for instance:

 >>> a.samples_
-array([4.6702595 , 5.36880945, 4.85768145, 5.1372535 , 5.9448457 ])
+array([4.51589278, 4.37788659, 5.25960812, 5.80609507, 4.33770499])

 To sample using e.g. Latin Hypercube, do the following:

 >>> from scipy.stats.qmc import LatinHypercube
->>> d = expression.get_dimensionality()
+>>> d = expression.get_number_of_distribution_nodes()
 >>> hypercube = LatinHypercube(d=d, rng=rng)
 >>> hypercube_samples = hypercube.random(5)  # Draw 5 samples
 >>> expression.sample_from_cube(hypercube_samples)
 array([ 1.20438726, 12.40283222, 5.02130766, 16.45109076, 77.12874028])

-Here is a more complex expression:
+Here is an even more complex expression:

 >>> a = Distribution("norm", loc=0, scale=1)
 >>> b = Distribution("norm", loc=0, scale=2)
 >>> c = Distribution("norm", loc=0, scale=3)
 >>> expression = a*a - Add(a, b, c) + Abs(b)**Abs(c) + Exp(1 / Abs(c))
 >>> expression.sample(5, random_state=rng)
-array([-3.75434563, 5.84160178, 50.58877597, -1.32687877, 81.00831756])
-
+array([ 4.70542018, 14.43250192, 6.74494838, -0.14020459, -3.27334554])

 Functions
+---------

 If you have a function that is not an arithmetic expression, you can still
 Monte-Carlo simulate through it with the `scalar_transform` decorator, which
-will pass each sample through the computation node in a loop when we sample:
+will pass each sample through the computation node in a loop:

 >>> def function(a, b):
 ...     if a > 0:
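The docstring above passes a Latin Hypercube design to `sample_from_cube`. As a rough illustration of what such a cube contains, here is a stdlib-only sketch (not scipy's `LatinHypercube`, and `latin_hypercube` is a hypothetical name). It returns rows of quantiles with one row per sample and one column per distribution, which is the layout implied by `size, n_dim = cube.shape` in the code.

```python
import random


def latin_hypercube(n_samples, n_dims, seed=0):
    """Return n_samples rows of n_dims quantiles in [0, 1).

    Every column places exactly one point in each of the n_samples
    equal-width strata, visited in shuffled order.
    """
    rng = random.Random(seed)
    columns = []
    for _ in range(n_dims):
        strata = list(range(n_samples))
        rng.shuffle(strata)
        # One uniform draw inside each stratum [k/n, (k+1)/n)
        columns.append([(k + rng.random()) / n_samples for k in strata])
    # Transpose so that each row is one sample across all dimensions
    return [list(row) for row in zip(*columns)]


cube = latin_hypercube(5, 2)
```

Compared with plain uniform draws, each marginal is covered evenly, which tends to reduce variance for a fixed number of samples.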
@@ -100,19 +109,18 @@

 >>> a = Distribution("norm", loc=0, scale=1)
 >>> b = Distribution("norm", loc=0, scale=2)
->>> expression = function(a, b)  # Function is not called here
+>>> expression = function(a, b)  # Function is not actually called here

 Now sample 'through' the function:

 >>> expression.sample(5, random_state=rng)
-array([ 0. , 0. , -0.13902087, 1.01335768, 0. ])
+array([0. , 0. , 0.45555522, 0. , 0. ])
 """

 import operator
 import functools
 import numpy as np
 import numbers
-import dataclasses
 from scipy import stats
 import abc
 import itertools
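The docstring says `scalar_transform` passes each sample through the function in a loop. The commit does not show its implementation, so the following is only a sketch of the idea with a hypothetical `elementwise` decorator. Unlike the library, which defers the call until `.sample()`, this sketch evaluates immediately on plain lists.

```python
import functools


def elementwise(func):
    """Hypothetical stand-in: lift a scalar function to sequences of samples."""

    @functools.wraps(func)
    def wrapper(*sample_columns):
        # Call the scalar function once per sample, i.e. once per row
        return [func(*row) for row in zip(*sample_columns)]

    return wrapper


@elementwise
def positive_product(a, b):
    # Branching on a scalar works because we loop over individual samples
    return a * b if a > 0 else 0.0


result = positive_product([1.0, -2.0, 3.0], [0.5, 0.5, 2.0])
```

A loop like this is slower than vectorized arithmetic, but it lets functions with `if` statements on scalar values take part in a simulation.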
@@ -204,55 +212,68 @@ def python_to_prob(argument):
 # is determined by the graph structure.


-@dataclasses.dataclass
 class Node(abc.ABC):
     """A node in the computational graph."""

-    id_iter = itertools.count()  # Everyone gets a unique ID
+    id_iter = itertools.count()  # Every node gets a unique ID
+
+    def __init__(self):
+        self._id = next(self.id_iter)

     def __eq__(self, other):
         return self._id == other._id

     def __hash__(self):
         return self._id

-    def __post_init__(self):
-        self._id = next(self.id_iter)
-
     def nodes(self):
-        """Yields all ancestors using depth-first-search, including `self`."""
+        """Yields `self` and all ancestors using depth-first-search.
+
+        Examples
+        --------
+        >>> expression = Distribution("norm") - 2**Constant(2)
+        >>> for node in expression.nodes():
+        ...     print(node)
+        Subtract(Distribution("norm"), Power(Constant(2), Constant(2)))
+        Power(Constant(2), Constant(2))
+        Constant(2)
+        Constant(2)
+        Distribution("norm")
+        """
         queue = [(self)]
         while queue:
             yield (node := queue.pop())
             queue.extend(node.get_parents())

-    def get_dimensionality(self):
+    def get_number_of_distribution_nodes(self):
         return sum(1 for node in set(self.nodes()) if isinstance(node, Distribution))

     def sample(self, size=None, random_state=None):
-        """Assign samples to self.samples_ rescursively."""
+        """Sample the current node and assign to all node.samples_."""
         size = 1 if size is None else size
         random_state = np.random.default_rng(random_state)

         # Draw a cube of random variables in [0, 1]
-        cube = random_state.random((size, self.get_dimensionality()))
+        cube = random_state.random((size, self.get_number_of_distribution_nodes()))

         return self.sample_from_cube(cube)

     def sample_from_cube(self, cube):
-        """Use samples from a cube of shape (dimensionality, num_samples)."""
+        """Use samples from a cube of quantiles in [0, 1] to sample all
+        distributions. The cube must have shape (num_samples, dimensionality)."""
         assert nx.is_directed_acyclic_graph(self.to_graph())
-
         size, n_dim = cube.shape
-        assert n_dim == self.get_dimensionality()
+        assert n_dim == self.get_number_of_distribution_nodes()
+
+        # Prepare columns of quantiles, one column for each Distribution
         columns = iter(list(cube.T))

-        # Clear any samples that might exist
+        # Clear any samples that might exist in the graph
         for node in set(self.nodes()):
             if hasattr(node, "samples_"):
                 delattr(node, "samples_")

-        # Sample leaf nodes that are distributions first
+        # Start with initial sampling nodes, which contain independent variables
         initial_sampling_nodes = [
             node for node in set(self.nodes()) if node._is_initial_sampling_node()
         ]
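The hunk above moves the ID assignment into `__init__` and relies on ID-based equality and hashing, which is why the code wraps `self.nodes()` in `set(...)` before counting or clearing. A stripped-down sketch of that pattern follows, using a hypothetical `SketchNode` class rather than the library's `Node`.

```python
import itertools


class SketchNode:
    """Hypothetical minimal node: identity by ID, parents stored in a tuple."""

    id_iter = itertools.count()  # Shared counter, one ID per instance

    def __init__(self, name, *parents):
        self._id = next(self.id_iter)
        self.name = name
        self.parents = parents

    def __eq__(self, other):
        return self._id == other._id

    def __hash__(self):
        return self._id

    def nodes(self):
        """Yield self and all ancestors, depth-first (stack based)."""
        stack = [self]
        while stack:
            yield (node := stack.pop())
            stack.extend(node.parents)


two = SketchNode("two")
power = SketchNode("power", two, two)  # The same node is used as both parents
root = SketchNode("root", SketchNode("norm"), power)

visited = [node.name for node in root.nodes()]
unique = {node.name for node in set(root.nodes())}
```

The traversal yields a shared node once per path that reaches it, so `visited` contains "two" twice while the set keeps a single copy.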
@@ -272,11 +293,8 @@ def sample_from_cube(self, cube):

         # TODO: correlate the samples

-        # Iterate over the remaining nodes and sample
-        remaining_nodes = nx.topological_sort(G)
-
         # Iterate from leaf nodes and up to parent
-        for node in remaining_nodes:
+        for node in nx.topological_sort(G):
             if hasattr(node, "samples_"):
                 continue
             elif isinstance(node, Constant):
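The loop above visits nodes in topological order so that every node is evaluated after the nodes it depends on. The sketch below shows the same idea with the standard library's `graphlib` instead of networkx. The dictionary-based graph and the node names are assumptions made for illustration only.

```python
import operator
from graphlib import TopologicalSorter

# Hypothetical graph: each node maps to (operation, names of its inputs)
graph = {
    "a": (lambda: 2.0, ()),
    "b": (lambda: 3.0, ()),
    "product": (operator.mul, ("a", "b")),
    "total": (operator.add, ("product", "a")),
}

# TopologicalSorter takes {node: predecessors}, so inputs are ordered first
dependencies = {name: inputs for name, (_, inputs) in graph.items()}
order = list(TopologicalSorter(dependencies).static_order())

values = {}
for name in order:
    func, inputs = graph[name]
    # All inputs are already in `values` thanks to the topological order
    values[name] = func(*(values[i] for i in inputs))
```

Because inputs always come first, a single pass over `order` is enough and no node is evaluated before its dependencies.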
@@ -285,6 +303,8 @@ def sample_from_cube(self, cube):
                 node.samples_ = node._sample(q=next(columns))
             elif isinstance(node, Transform):
                 node.samples_ = node._sample()
+            else:
+                raise TypeError("Node must be Constant, Distribution or Transform.")

         return self.samples_

@@ -333,7 +353,7 @@ def to_graph(self):


 class OverloadMixin:
-    """Overloads dunder (double underscore) methods."""
+    """Overloads dunder (double underscore) methods for easier modeling."""

     def __add__(self, other):
         return Add(self, other)
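`OverloadMixin` is what lets expressions such as `a * 3 + 5` build graph nodes instead of computing numbers. The following self-contained sketch shows the mechanism with hypothetical `Expr`, `Value` and `BinOp` classes. It is not the library's class hierarchy, only an illustration of overloaded operators that return lazy nodes.

```python
class Expr:
    """Hypothetical lazy expression node with overloaded operators."""

    def __add__(self, other):
        return BinOp("+", self, wrap(other))

    def __radd__(self, other):
        return BinOp("+", wrap(other), self)

    def __mul__(self, other):
        return BinOp("*", self, wrap(other))

    def __rmul__(self, other):
        return BinOp("*", wrap(other), self)


class Value(Expr):
    def __init__(self, value):
        self.value = value

    def evaluate(self):
        return self.value


class BinOp(Expr):
    def __init__(self, symbol, left, right):
        self.symbol, self.left, self.right = symbol, left, right

    def evaluate(self):
        lhs, rhs = self.left.evaluate(), self.right.evaluate()
        return lhs + rhs if self.symbol == "+" else lhs * rhs


def wrap(obj):
    """Promote plain numbers to Value nodes, leave nodes untouched."""
    return obj if isinstance(obj, Expr) else Value(obj)


expression = 5 + Value(1) * 3  # Builds a graph, nothing is computed yet
result = expression.evaluate()  # The arithmetic happens here
```

The reflected methods (`__radd__`, `__rmul__`) are what make a plain number on the left-hand side work, which mirrors the rule that at least one operand must be a node.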
@@ -375,7 +395,7 @@ def __abs__(self):
 class Constant(Node, OverloadMixin):
     """A constant is a number."""

-    is_leaf = True
+    is_leaf = True  # A Constant is always a leaf node

     def __init__(self, value):
         self.value = value
@@ -387,7 +407,7 @@ def _sample(self, size=None, random_state=None):
         return np.ones(size, dtype=type(self.value)) * self.value

     def get_parents(self):
-        return []
+        return []  # A Constant does not have any parents

     def __repr__(self):
         return f"{type(self).__name__}({self.value})"
@@ -414,16 +434,19 @@ def __repr__(self):

     def _sample(self, q):
         def unpack(arg):
+            """Unpack distribution arguments (parents) to arrays if Node."""
             return arg.samples_ if isinstance(arg, Node) else arg

+        # Parse the arguments and keyword arguments for the distribution
         args = tuple(unpack(arg) for arg in self.args)
         kwargs = {k: unpack(v) for (k, v) in self.kwargs.items()}

+        # Sample from the distribution with inverse CDF
         distribution = getattr(stats, self.distr)
         return distribution(*args, **kwargs).ppf(q)

     def get_parents(self):
-        # A distribution only has parents if its parameters are Nodes
+        # A distribution only has parents if it has parameters that are Nodes
         for arg in self.args + tuple(self.kwargs.values()):
             if isinstance(arg, Node):
                 yield arg
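`Distribution._sample` turns a column of quantiles into samples by calling `.ppf(q)`, the inverse CDF. The same transform can be sketched without scipy by using `statistics.NormalDist.inv_cdf` from the standard library. The helper name `sample_normal_from_quantiles` is hypothetical, and the sketch covers the normal distribution only.

```python
from statistics import NormalDist


def sample_normal_from_quantiles(quantiles, loc=0.0, scale=1.0):
    """Map quantiles in (0, 1) to normal samples through the inverse CDF."""
    dist = NormalDist(mu=loc, sigma=scale)
    return [dist.inv_cdf(q) for q in quantiles]


# Fixed quantiles stand in for one column of the sampling cube
quantiles = [0.025, 0.5, 0.975]
samples = sample_normal_from_quantiles(quantiles, loc=5.0, scale=1.0)
```

The median quantile maps to the mean, and the outer two quantiles map to roughly the mean minus and plus 1.96 standard deviations. Feeding in uniform or Latin Hypercube quantiles therefore yields normally distributed samples.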
@@ -447,6 +470,12 @@ def __repr__(self):


 class VariadicTransform(Transform):
+    """Parent class for variadic transforms (must be associative), e.g.
+        Add(arg1, arg2, arg3, arg4, ...)
+        Multiply(arg1, arg2, arg3, arg4, ...)
+
+    """
+
     def __init__(self, *args):
         self.parents = tuple(python_to_prob(arg) for arg in args)
         super().__init__()
@@ -468,6 +497,8 @@ class Multiply(VariadicTransform):


 class BinaryTransform(Transform):
+    """Class for binary transforms, such as Divide, Power, Subtract, etc."""
+
     def __init__(self, *args):
         self.parents = tuple(python_to_prob(arg) for arg in args)
         super().__init__()
@@ -493,6 +524,9 @@ class Subtract(BinaryTransform):


 class UnaryTransform(Transform):
+    """Class for unary transforms, i.e. functions that take one argument, such
+    as Abs(), Exp(), Log()."""
+
     def __init__(self, arg):
         self.parent = python_to_prob(arg)
         super().__init__()

tests/test_modeling.py

Lines changed: 6 additions & 0 deletions
@@ -53,6 +53,12 @@ def test_constant_expressions():
     np.testing.assert_allclose(result.sample(), 1 / 5 - (np.log(5) + 10))


+def test_single_expression():
+    # A graph with a single node is an edge-case
+    samples = Constant(2).sample()
+    np.testing.assert_allclose(samples, 2)
+
+
 if __name__ == "__main__":
     import pytest
