-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_hidden_markov_model.py
More file actions
173 lines (129 loc) · 7.21 KB
/
test_hidden_markov_model.py
File metadata and controls
173 lines (129 loc) · 7.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import chex
import equinox as eqx
import jax
import jax.numpy as jnp
import pytest
from simplexity.generative_processes.builder import build_hidden_markov_model
from simplexity.generative_processes.hidden_markov_model import HiddenMarkovModel
from tests.assertions import assert_proportional
@pytest.fixture
def z1r() -> HiddenMarkovModel:
    """Provide the "zero_one_random" hidden Markov model built with p=0.5."""
    model = build_hidden_markov_model("zero_one_random", p=0.5)
    return model
def test_properties(z1r: HiddenMarkovModel):
    """Check vocabulary/state counts and that eigenvector and initial state are uniform."""
    expected_vocab, expected_states = 2, 3
    assert z1r.vocab_size == expected_vocab
    assert z1r.num_states == expected_states

    # Both vectors only need to match up to a scale factor.
    uniform = jnp.ones(3)
    assert_proportional(z1r.normalizing_eigenvector, uniform)
    assert_proportional(z1r.initial_state, uniform)
def test_normalize_belief_state(z1r: HiddenMarkovModel):
    """Normalization rescales to unit total mass; an all-zero state is expected to give NaNs."""
    unnormalized = jnp.array([2, 5, 1])  # total mass 8
    normalized = z1r.normalize_belief_state(unnormalized)
    chex.assert_trees_all_close(normalized, jnp.array([0.25, 0.625, 0.125]))

    zero_mass = jnp.array([0, 0, 0])
    normalized_zero = z1r.normalize_belief_state(zero_mass)
    assert jnp.all(jnp.isnan(normalized_zero))
def test_normalize_log_belief_state(z1r: HiddenMarkovModel):
    """Log-space normalization matches the linear case, including NaNs for zero mass."""
    log_unnormalized = jnp.log(jnp.array([2, 5, 1]))
    normalized_log = z1r.normalize_log_belief_state(log_unnormalized)
    chex.assert_trees_all_close(normalized_log, jnp.log(jnp.array([0.25, 0.625, 0.125])))

    # log(0) everywhere: the state carries no probability mass at all.
    log_zero_mass = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf])
    normalized_zero = z1r.normalize_log_belief_state(log_zero_mass)
    assert jnp.all(jnp.isnan(normalized_zero))
def test_single_transition(z1r: HiddenMarkovModel):
    """A single generation step updates pure and mixed belief states correctly.

    The cases below encode the model's dynamics: state 0 emits 0 and moves to
    state 1, state 1 emits 1 and moves to state 2, and state 2 emits 0 or 1
    (probability 0.5 each) and moves to state 0.
    """
    to_probs = eqx.filter_vmap(z1r.normalize_belief_state)
    key = jax.random.PRNGKey(0)[None, :]  # a batch containing one key
    steps = 1

    zero_state = jnp.array([[1.0, 0.0, 0.0]])
    one_state = jnp.array([[0.0, 1.0, 0.0]])
    random_state = jnp.array([[0.0, 0.0, 1.0]])

    # (start belief, expected next belief, expected token; None when the emission is random)
    pure_cases = [
        (zero_state, one_state, 0),
        (one_state, random_state, 1),
        (random_state, zero_state, None),
    ]
    for start, expected_next, expected_obs in pure_cases:
        next_state, observation = z1r.generate(start, key, steps, False)
        assert_proportional(to_probs(next_state), expected_next)
        if expected_obs is not None:
            assert observation == jnp.array(expected_obs)

    # Mixed belief: Bayes update conditioned on whichever token was sampled.
    #   P(next=0 | obs) ∝ P(obs | prev=2) * P(prev=2)
    #   P(next=1 | obs) ∝ P(obs | prev=0) * P(prev=0)
    #   P(next=2 | obs) ∝ P(obs | prev=1) * P(prev=1)
    # obs=0: (0.5*0.2, 1.0*0.4, 0.0*0.4) = (0.1, 0.4, 0.0) -> (0.2, 0.8, 0.0)
    # obs=1: (0.5*0.2, 0.0*0.4, 1.0*0.4) = (0.1, 0.0, 0.4) -> (0.2, 0.0, 0.8)
    mixed_state = jnp.array([[0.4, 0.4, 0.2]])
    next_state, observation = z1r.generate(mixed_state, key, steps, False)
    expected_mixed = jnp.array([[0.2, 0.8, 0.0]]) if observation == 0 else jnp.array([[0.2, 0.0, 0.8]])
    assert_proportional(to_probs(next_state), expected_mixed)
def test_generate(z1r: HiddenMarkovModel):
    """Without intermediate states, generate returns final states plus every observation.

    Two consecutive segments are generated. The second resumes from the states
    the first ended in and uses keys derived from a different seed.
    """
    batch_size, sequence_len = 4, 10
    states = jnp.repeat(z1r.normalizing_eigenvector[None, :], batch_size, axis=0)

    for seed in (0, 1):
        keys = jax.random.split(jax.random.PRNGKey(seed), batch_size)
        states, observations = z1r.generate(states, keys, sequence_len, False)
        assert states.shape == (batch_size, z1r.num_states)
        assert observations.shape == (batch_size, sequence_len)
def test_generate_with_intermediate_states(z1r: HiddenMarkovModel):
    """With intermediate states, generate returns one state per generated step.

    The second segment resumes from the last state of the first segment. It uses
    freshly split keys so that it does not reuse the first segment's randomness.
    """
    batch_size = 4
    sequence_len = 10
    initial_states = jnp.repeat(z1r.normalizing_eigenvector[None, :], batch_size, axis=0)

    keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
    intermediate_states, intermediate_observations = z1r.generate(initial_states, keys, sequence_len, True)
    assert intermediate_states.shape == (batch_size, sequence_len, z1r.num_states)
    assert intermediate_observations.shape == (batch_size, sequence_len)

    last_intermediate_states = intermediate_states[:, -1, :]
    # JAX keys must not be reused: passing the same `keys` again would replay the
    # first segment's random draws, so derive new ones from a different seed.
    keys = jax.random.split(jax.random.PRNGKey(1), batch_size)
    final_states, final_observations = z1r.generate(last_intermediate_states, keys, sequence_len, True)
    assert final_states.shape == (batch_size, sequence_len, z1r.num_states)
    assert final_observations.shape == (batch_size, sequence_len)
def test_generate_with_obs_dist(z1r: HiddenMarkovModel):
    """generate_with_obs_dist returns per-step states, observations, and observation probabilities.

    The second segment resumes from the last state of the first segment. It uses
    freshly split keys so that it does not reuse the first segment's randomness.
    """
    batch_size = 4
    sequence_len = 10
    initial_states = jnp.repeat(z1r.normalizing_eigenvector[None, :], batch_size, axis=0)

    keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
    intermediate_states, intermediate_observations, intermediate_obs_probs = z1r.generate_with_obs_dist(
        initial_states, keys, sequence_len
    )
    assert intermediate_states.shape == (batch_size, sequence_len, z1r.num_states)
    assert intermediate_observations.shape == (batch_size, sequence_len)
    assert intermediate_obs_probs.shape == (batch_size, sequence_len, z1r.vocab_size)

    last_intermediate_states = intermediate_states[:, -1, :]
    # JAX keys must not be reused: passing the same `keys` again would replay the
    # first segment's random draws, so derive new ones from a different seed.
    keys = jax.random.split(jax.random.PRNGKey(1), batch_size)
    final_states, final_observations, final_obs_probs = z1r.generate_with_obs_dist(
        last_intermediate_states, keys, sequence_len
    )
    assert final_states.shape == (batch_size, sequence_len, z1r.num_states)
    assert final_observations.shape == (batch_size, sequence_len)
    assert final_obs_probs.shape == (batch_size, sequence_len, z1r.vocab_size)
def test_observation_probability_distribution(z1r: HiddenMarkovModel):
    """The observation distribution is the belief-weighted emission distribution.

    For this model, with belief state b:
    P(obs=0) = b[0] + 0.5 * b[2] and P(obs=1) = b[1] + 0.5 * b[2].
    """
    state = jnp.array([0.3, 0.1, 0.6])
    obs_probs = z1r.observation_probability_distribution(state)
    chex.assert_trees_all_close(obs_probs, jnp.array([0.6, 0.4]))

    state = jnp.array([0.5, 0.3, 0.2])
    obs_probs = z1r.observation_probability_distribution(state)
    chex.assert_trees_all_close(obs_probs, jnp.array([0.6, 0.4]))

    # Both cases above expect [0.6, 0.4], so an implementation that ignored its
    # input would pass them. This case has a different, exactly representable
    # answer: 0.5 + 0.5*0.25 = 0.625 and 0.25 + 0.5*0.25 = 0.375.
    state = jnp.array([0.5, 0.25, 0.25])
    obs_probs = z1r.observation_probability_distribution(state)
    chex.assert_trees_all_close(obs_probs, jnp.array([0.625, 0.375]))
def test_log_observation_probability_distribution(z1r: HiddenMarkovModel):
    """Log-space observation distribution: normalized, and equal to log of the linear result.

    For this model, with belief state b:
    P(obs=0) = b[0] + 0.5 * b[2] and P(obs=1) = b[1] + 0.5 * b[2].
    """
    log_belief_state = jnp.log(jnp.array([0.3, 0.1, 0.6]))
    log_obs_probs = z1r.log_observation_probability_distribution(log_belief_state)
    assert jnp.isclose(jax.nn.logsumexp(log_obs_probs), 0, atol=1e-7)
    chex.assert_trees_all_close(log_obs_probs, jnp.log(jnp.array([0.6, 0.4])))

    log_belief_state = jnp.log(jnp.array([0.5, 0.3, 0.2]))
    log_obs_probs = z1r.log_observation_probability_distribution(log_belief_state)
    assert jnp.isclose(jax.nn.logsumexp(log_obs_probs), 0, atol=1e-7)
    chex.assert_trees_all_close(log_obs_probs, jnp.log(jnp.array([0.6, 0.4])))

    # Both cases above expect log([0.6, 0.4]), so an implementation that ignored
    # its input would pass them. This belief has a different answer:
    # [0.5 + 0.5*0.25, 0.25 + 0.5*0.25] = [0.625, 0.375].
    log_belief_state = jnp.log(jnp.array([0.5, 0.25, 0.25]))
    log_obs_probs = z1r.log_observation_probability_distribution(log_belief_state)
    chex.assert_trees_all_close(log_obs_probs, jnp.log(jnp.array([0.625, 0.375])))
def test_probability(z1r: HiddenMarkovModel):
    """Check the sequence probability of 1,0,0,1,1,0.

    Assuming a uniform start: only a start in state 1 can produce this sequence.
    It needs two emissions from the random state (0.5 each), so the probability
    is (1/3) * 0.5 * 0.5 = 1/12.
    """
    sequence = jnp.array([1, 0, 0, 1, 1, 0])
    assert jnp.isclose(z1r.probability(sequence), 1 / 12)
def test_log_probability(z1r: HiddenMarkovModel):
    """Check that the log-probability of 1,0,0,1,1,0 equals log(1/12).

    Assuming a uniform start: only state 1 fits the sequence, with two 0.5
    random emissions, giving (1/3) * 0.5 * 0.5 = 1/12.
    """
    sequence = jnp.array([1, 0, 0, 1, 1, 0])
    expected_log_prob = jnp.log(1 / 12)
    assert jnp.isclose(z1r.log_probability(sequence), expected_log_prob)