-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathgeneralized_hidden_markov_model.py
More file actions
159 lines (126 loc) · 7.09 KB
/
generalized_hidden_markov_model.py
File metadata and controls
159 lines (126 loc) · 7.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from typing import TypeVar, cast
import chex
import equinox as eqx
import jax
import jax.numpy as jnp
from simplexity.generative_processes.generative_process import GenerativeProcess
from simplexity.generative_processes.transition_matrices import stationary_state
# Type variable for the state handled by a generative process; bounded by jax.Array.
State = TypeVar("State", bound=jax.Array)
class GeneralizedHiddenMarkovModel(GenerativeProcess[State]):
    """A Generalized Hidden Markov Model.

    The model is parameterized by a stack of per-observation transition
    matrices of shape ``(vocab_size, num_states, num_states)``. At construction
    time the matrices are rescaled so that their sum over observations has
    principal eigenvalue 1, and the corresponding right eigenvector is stored
    as the ``normalizing_eigenvector`` used to turn state vectors into
    probabilities.
    """

    # Per-observation transition matrices, rescaled by the principal eigenvalue.
    transition_matrices: jax.Array
    # Element-wise log of ``transition_matrices``.
    log_transition_matrices: jax.Array
    # Principal right eigenvector of the summed transition matrix, scaled to sum to ``num_states``.
    normalizing_eigenvector: jax.Array
    # Element-wise log of ``normalizing_eigenvector``.
    log_normalizing_eigenvector: jax.Array
    # State the process starts from (exposed through the ``initial_state`` property).
    _initial_state: jax.Array
    # Element-wise log of ``_initial_state``.
    log_initial_state: jax.Array
    # ``_initial_state @ normalizing_eigenvector``.
    normalizing_constant: jax.Array
    # Log of ``normalizing_constant``, computed with logsumexp.
    log_normalizing_constant: jax.Array

    def __init__(self, transition_matrices: jax.Array, initial_state: jax.Array | None = None):
        """Build the model from its per-observation transition matrices.

        Args:
            transition_matrices: array of shape (vocab_size, num_states, num_states).
            initial_state: optional starting state vector. When ``None``, it is set to
                ``stationary_state`` of the transposed summed transition matrix.

        Raises:
            ValueError: if ``transition_matrices`` does not have the expected shape.
        """
        self.validate_transition_matrices(transition_matrices)

        state_transition_matrix = jnp.sum(transition_matrices, axis=0)
        eigenvalues, right_eigenvectors = jnp.linalg.eig(state_transition_matrix)

        # jnp.linalg.eig always returns complex arrays. Work with the real part of the
        # principal eigenvalue so that rescaling does not turn the (real) transition
        # matrices into complex ones.
        principal_index = jnp.argmax(eigenvalues.real)
        principal_eigenvalue = eigenvalues[principal_index].real

        if jnp.isclose(principal_eigenvalue, 1):
            self.transition_matrices = transition_matrices
        else:
            self.transition_matrices = transition_matrices / principal_eigenvalue
        # Take the log of the *stored* matrices so both attributes describe the same model.
        self.log_transition_matrices = jnp.log(self.transition_matrices)

        # Index a single column (instead of boolean-masking and squeezing) so the result is
        # always a (num_states,) vector, including for one-state models and for a
        # degenerate principal eigenvalue.
        normalizing_eigenvector = right_eigenvectors[:, principal_index].real
        self.normalizing_eigenvector = normalizing_eigenvector / jnp.sum(normalizing_eigenvector) * self.num_states
        self.log_normalizing_eigenvector = jnp.log(self.normalizing_eigenvector)

        if initial_state is None:
            initial_state = stationary_state(state_transition_matrix.T)
        self._initial_state = initial_state
        self.log_initial_state = jnp.log(self._initial_state)

        self.normalizing_constant = self._initial_state @ self.normalizing_eigenvector
        self.log_normalizing_constant = jax.nn.logsumexp(self.log_initial_state + self.log_normalizing_eigenvector)

    def validate_transition_matrices(self, transition_matrices: jax.Array):
        """Validate the transition matrices.

        Raises:
            ValueError: if the array is not 3-dimensional or its last two axes differ in size.
        """
        if transition_matrices.ndim != 3 or transition_matrices.shape[1] != transition_matrices.shape[2]:
            raise ValueError("Transition matrices must have shape (vocab_size, num_states, num_states)")

    @property
    def vocab_size(self) -> int:
        """The number of distinct observations that can be emitted by the model."""
        return self.transition_matrices.shape[0]

    @property
    def num_states(self) -> int:
        """The number of hidden states in the model."""
        return self.transition_matrices.shape[1]

    @property
    def initial_state(self) -> State:
        """The initial state of the model."""
        return cast(State, self._initial_state)

    @eqx.filter_vmap(in_axes=(None, 0, 0, None))
    def generate_with_obs_dist(
        self, state: State, key: chex.PRNGKey, sequence_len: int
    ) -> tuple[State, chex.Array, chex.Array]:
        """Generate a batch of sequences of observations from the generative process.

        The method is vmapped over the leading axis of ``state`` and ``key``;
        ``sequence_len`` is shared by the whole batch.

        Inputs:
            state: (batch_size, num_states)
            key: (batch_size, 2)
            sequence_len: number of observations to generate per sequence

        Returns: tuple of (belief states, observations, observation probabilities) where:
            states: (batch_size, sequence_len, num_states)
            obs: (batch_size, sequence_len)
            obs_probs: (batch_size, sequence_len, vocab_size)

        The state recorded at each step is the one the observation was emitted
        from, i.e. the state *before* the transition for that step.
        """
        keys = jax.random.split(key, sequence_len)

        def gen_sequences(state: State, key: chex.PRNGKey) -> tuple[State, tuple[State, chex.Array, chex.Array]]:
            obs_probs = self.observation_probability_distribution(state)
            obs = jax.random.choice(key, self.vocab_size, p=obs_probs)
            new_state = self.transition_states(state, obs)
            return new_state, (state, obs, obs_probs)

        _, (states, obs, obs_probs) = jax.lax.scan(gen_sequences, state, keys)
        return states, obs, obs_probs

    @eqx.filter_jit
    def emit_observation(self, state: State, key: chex.PRNGKey) -> jax.Array:
        """Emit an observation based on the state of the generative process."""
        obs_probs = self.observation_probability_distribution(state)
        return jax.random.choice(key, self.vocab_size, p=obs_probs)

    @eqx.filter_jit
    def transition_states(self, state: State, obs: chex.Array) -> State:
        """Evolve the state of the generative process based on the observation.

        The input state represents a prior distribution over hidden states, and
        the returned state represents a posterior distribution over hidden states
        conditioned on the observation.
        """
        state = cast(State, state @ self.transition_matrices[obs])
        # Renormalize so that the new state has unit weight under the normalizing eigenvector.
        return cast(State, state / (state @ self.normalizing_eigenvector))

    @eqx.filter_jit
    def normalize_belief_state(self, state: State) -> jax.Array:
        """Compute the probability distribution over states from a state vector.

        The state is divided by its inner product with the normalizing eigenvector.

        NOTE: returns nans when state is zeros
        """
        return state / (state @ self.normalizing_eigenvector)

    @eqx.filter_jit
    def normalize_log_belief_state(self, log_belief_state: jax.Array) -> jax.Array:
        """Compute the log probability distribution over states from a log state vector.

        Log-space counterpart of ``normalize_belief_state``.

        NOTE: returns nans when log_belief_state is -infs (state is zeros)
        """
        return log_belief_state - jax.nn.logsumexp(log_belief_state + self.log_normalizing_eigenvector)

    @eqx.filter_jit
    def observation_probability_distribution(self, state: State) -> jax.Array:
        """Compute the probability distribution of the observations that can be emitted by the process.

        Returns an array with one entry per observation symbol.
        """
        return (state @ self.transition_matrices @ self.normalizing_eigenvector) / (
            state @ self.normalizing_eigenvector
        )

    @eqx.filter_jit
    def log_observation_probability_distribution(self, log_belief_state: State) -> jax.Array:
        """Compute the log probability distribution of the observations that can be emitted by the process.

        Currently implemented by exponentiating the input, computing the
        distribution in probability space and taking the log of the result.
        """
        # TODO: fix log math (https://github.com/Astera-org/simplexity/issues/9)
        state = cast(State, jnp.exp(log_belief_state))
        obs_prob_dist = self.observation_probability_distribution(state)
        return jnp.log(obs_prob_dist)

    @eqx.filter_jit
    def probability(self, observations: jax.Array) -> jax.Array:
        """Compute the probability of the process generating a sequence of observations.

        Starting from the initial state, the state vector is multiplied by the
        transition matrix of each observation in turn; the result is projected
        onto the normalizing eigenvector and divided by the normalizing constant.
        """

        def _scan_fn(state_vector, observation):
            return state_vector @ self.transition_matrices[observation], None

        state_vector, _ = jax.lax.scan(_scan_fn, init=self._initial_state, xs=observations)
        return (state_vector @ self.normalizing_eigenvector) / self.normalizing_constant

    @eqx.filter_jit
    def log_probability(self, observations: jax.Array) -> jax.Array:
        """Compute the log probability of the process generating a sequence of observations.

        Currently implemented as the log of ``probability(observations)``.
        """
        # TODO: fix log math (https://github.com/Astera-org/simplexity/issues/9)
        prob = self.probability(observations)
        return jnp.log(prob)