-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtest_modeling.py
More file actions
70 lines (53 loc) · 2.1 KB
/
test_modeling.py
File metadata and controls
70 lines (53 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from probabilit.modeling import Constant, Log, Exp
import numpy as np
def test_constant_arithmetic():
    """Check that arithmetic on Constant nodes matches plain Python numbers.

    Subtraction and division are exercised in three forms: Constant on both
    sides, a plain number on the left, and a plain number on the right.
    """
    # Test that conversion with int works
    two = Constant(2)
    result = two + 2
    np.testing.assert_allclose(result.sample(), 4)

    # Test that subtraction works both ways
    two = Constant(2)
    five = Constant(5)
    result1 = five - two  # Constant on both sides
    result2 = 5 - two  # plain number on the left
    result3 = five - 2  # plain number on the right
    np.testing.assert_allclose(result1.sample(), result2.sample())
    np.testing.assert_allclose(result1.sample(), result3.sample())
    np.testing.assert_allclose(result1.sample(), 5 - 2)

    # Test that division works both ways
    two = Constant(2)
    five = Constant(5)
    result1 = five / two  # Constant on both sides
    result2 = 5 / two  # plain number on the left
    result3 = five / 2  # plain number on the right
    np.testing.assert_allclose(result1.sample(), result2.sample())
    np.testing.assert_allclose(result1.sample(), result3.sample())
    np.testing.assert_allclose(result1.sample(), 5 / 2)

    # Test absolute value and negation
    result = abs(-two)
    np.testing.assert_allclose(result.sample(), 2)

    # Test powers
    result = five**two
    np.testing.assert_allclose(result.sample(), 5**2)
def test_constant_expressions():
    """Check that longer compound expressions sample to the plain-number value."""
    c2 = Constant(2)
    c5 = Constant(5)

    # Addition, subtraction, an integer power and an absolute value combined.
    expression = c2 + c2 - c5**2 + abs(-c5)
    expected = 2 + 2 - 5**2 + abs(-5)
    np.testing.assert_allclose(expression.sample(), expected)

    # Division and a power mixed with an exponential node.
    expression = c2 / c5 - c2**3 + Exp(5)
    expected = 2 / 5 - 2**3 + np.exp(5)
    np.testing.assert_allclose(expression.sample(), expected)

    # Nested Log/Exp nodes; Exp(Log(10)) is compared against 10.
    expression = 1 / c5 - (Log(5) + Exp(Log(10)))
    expected = 1 / 5 - (np.log(5) + 10)
    np.testing.assert_allclose(expression.sample(), expected)
def test_single_expression():
    """Sample a graph that consists of one lone node, which is an edge case."""
    lone_node = Constant(2)
    np.testing.assert_allclose(lone_node.sample(), 2)
def test_constant_idempotent():
    """Wrapping a Constant in another Constant must leave `.value` unchanged."""
    for raw in [-1, 0.0, 1.3, 3]:
        wrapped_twice = Constant(Constant(raw))
        wrapped_once = Constant(raw)
        assert wrapped_twice.value == wrapped_once.value
if __name__ == "__main__":
    # When executed as a script, run this file's tests through pytest with
    # verbose output and sys-level capturing.
    import pytest

    pytest_arguments = [__file__, "-v", "--capture=sys"]
    pytest.main(args=pytest_arguments)