Skip to content

Commit 4309633

Browse files
authored
Floordiv, modulus, consistent ordering (#15)
* floordiv * sort variables in correlations for consistent results * add Mod, i.e. the % operator
1 parent 64193d4 commit 4309633

File tree

1 file changed

+52
-14
lines changed

1 file changed

+52
-14
lines changed

src/probabilit/modeling.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
--------
44
55
Probabilit lets the user perform Monte-Carlo sampling using a high-level
6-
modeling language, which creates a computational graph.
6+
modeling language.
7+
The modeling language creates a lazy computational graph.
8+
When a node is sampled, all ancestor nodes are sampled in turn and the samples are propagated down in the graph, from parent nodes to child nodes.
79
810
For instance, to compute the shipping cost of a box where we are uncertain
911
about the measurements:
@@ -20,21 +22,21 @@
2022
20.00139737515...
2123
2224
Distributions are built on top of scipy, so "norm" refers to the name of the
23-
normal distribution as given in `scipy.stats`, and the arguments to the distribution
24-
must also match those given by `scipy.stats.norm`.
25+
normal distribution as given in `scipy.stats`, and the arguments to the
26+
distribution must also match those given by `scipy.stats.norm`.
2527
26-
Here is another example showing composite distributions, where the argument
27-
to one distribution is another distribution:
28+
Here is another example demonstrating composite distributions, where an
29+
argument to one distribution is another distribution:
2830
2931
>>> eggs_per_nest = Distribution("poisson", mu=3)
3032
>>> survivial_prob = Distribution("beta", a=10, b=15)
3133
>>> survived = Distribution("binom", n=eggs_per_nest, p=survivial_prob)
3234
>>> survived.sample(9, random_state=rng)
3335
array([0., 1., 2., 0., 3., 1., 1., 0., 2.])
3436
35-
To understand and examine the modeling language, we can do some computations
36-
with constants. The computational graph carries out arithmetic operations lazily
37-
once a model is sampled. Mixing numbers with nodes is allowed, but at least one
37+
To understand and examine the modeling language, we can perform computations
38+
using constants. The computational graph carries out arithmetic operations
39+
when the model is sampled. Mixing numbers with nodes is allowed, but at least one
3840
expression or term must be a probabilit class instance:
3941
4042
>>> a = Constant(1)
@@ -62,7 +64,7 @@
6264
Multiply(Distribution("expon", scale=1), Constant(5))
6365
Add(Add(Power(Distribution("norm", loc=5, scale=1), Distribution("expon", scale=1)), Multiply(Distribution("norm", loc=5, scale=1), Distribution("expon", scale=1))), Multiply(Distribution("expon", scale=1), Constant(5)))
6466
65-
Sampling the expression is simple enough:
67+
Sampling the expression is simple:
6668
6769
>>> expression.sample(5, random_state=rng)
6870
array([ 2.70764145, 36.58578812, 7.07064239, 1.84433247, 3.90951632])
@@ -73,10 +75,10 @@
7375
>>> a.samples_
7476
array([4.51589278, 4.37788659, 5.25960812, 5.80609507, 4.33770499])
7577
76-
To sample using e.g. Latin Hypercube, do the following:
78+
To sample using e.g. the Latin Hypercube algorithm, do the following:
7779
7880
>>> from scipy.stats.qmc import LatinHypercube
79-
>>> d = expression.get_number_of_distribution_nodes()
81+
>>> d = expression.num_distribution_nodes()
8082
>>> hypercube = LatinHypercube(d=d, rng=rng)
8183
>>> hypercube_samples = hypercube.random(5) # Draw 5 samples
8284
>>> expression.sample_from_cube(hypercube_samples)
@@ -91,6 +93,19 @@
9193
>>> expression.sample(5, random_state=rng)
9294
array([ 4.70542018, 14.43250192, 6.74494838, -0.14020459, -3.27334554])
9395
96+
Nodes are hashable and can be used in sets, so __hash__ and __eq__ must both
97+
be defined. We cannot use `==` for modeling since equality in that context has
98+
another meaning. Use the Equal node instead. This is only relevant in cases
99+
when equality is part of a model. For real-valued distribution (e.g. Normal)
100+
equality does not make sense since the probability that two floats are equal
101+
is zero.
102+
103+
>>> dice1 = Distribution("uniform", loc=1, scale=6) // 1
104+
>>> dice2 = Distribution("uniform", loc=1, scale=6) // 1
105+
>>> equal_result = Equal(dice1, dice2)
106+
>>> float(equal_result.sample(999, random_state=42).mean())
107+
0.166...
108+
94109
Functions
95110
---------
96111
@@ -328,7 +343,7 @@ def nodes(self):
328343
yield (node := queue.pop())
329344
queue.extend(node.get_parents())
330345

331-
def get_number_of_distribution_nodes(self):
346+
def num_distribution_nodes(self):
332347
return sum(1 for node in set(self.nodes()) if isinstance(node, Distribution))
333348

334349
def sample(self, size=None, random_state=None):
@@ -337,7 +352,7 @@ def sample(self, size=None, random_state=None):
337352
random_state = check_random_state(random_state)
338353

339354
# Draw a cube of random variables in [0, 1]
340-
cube = random_state.random((size, self.get_number_of_distribution_nodes()))
355+
cube = random_state.random((size, self.num_distribution_nodes()))
341356

342357
return self.sample_from_cube(cube)
343358

@@ -346,7 +361,7 @@ def sample_from_cube(self, cube):
346361
distributions. The cube must have shape (dimensionality, num_samples)."""
347362
assert nx.is_directed_acyclic_graph(self.to_graph())
348363
size, n_dim = cube.shape
349-
assert n_dim == self.get_number_of_distribution_nodes()
364+
assert n_dim == self.num_distribution_nodes()
350365

351366
# Prepare columns of quantiles, one column for each Distribution
352367
columns = iter(list(cube.T))
@@ -377,6 +392,7 @@ def sample_from_cube(self, cube):
377392
node.samples_ = node._sample(q=next(columns))
378393

379394
# Go through all ancestor nodes and create a list [(var, corr), ...]
395+
# that contains all correlations we must induce
380396
correlations = []
381397
for node in set(self.nodes()):
382398
if hasattr(node, "_correlations"):
@@ -397,6 +413,8 @@ def sample_from_cube(self, cube):
397413

398414
# Map all variables to integers
399415
all_variables = list(functools.reduce(set.union, variable_sets, set()))
416+
# Ensure consistent ordering for reproducible results
417+
all_variables = sorted(all_variables, key=lambda n: n._id)
400418
var_to_int = {v: i for (i, v) in enumerate(all_variables)}
401419
correlations = [
402420
(tuple(var_to_int[var] for var in variables), corrmat)
@@ -508,12 +526,24 @@ def __mul__(self, other):
508526
def __rmul__(self, other):
509527
return Multiply(self, other)
510528

529+
def __floordiv__(self, other):
530+
return FloorDivide(self, other)
531+
532+
def __rfloordiv__(self, other):
533+
return FloorDivide(other, self)
534+
511535
def __truediv__(self, other):
512536
return Divide(self, other)
513537

514538
def __rtruediv__(self, other):
515539
return Divide(other, self)
516540

541+
def __mod__(self, other):
542+
return Mod(self, other)
543+
544+
def __rmod__(self, other):
545+
return Mod(other, self)
546+
517547
def __sub__(self, other):
518548
return Subtract(self, other)
519549

@@ -691,6 +721,14 @@ def get_parents(self):
691721
yield from self.parents
692722

693723

724+
class FloorDivide(BinaryTransform):
725+
op = np.floor_divide
726+
727+
728+
class Mod(BinaryTransform):
729+
op = np.mod
730+
731+
694732
class Divide(BinaryTransform):
695733
op = operator.truediv
696734

0 commit comments

Comments
 (0)