Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ __pycache__/
data/
multirun/
outputs/
references/
config.ini
*temp*
# MLflow tracking data
Expand Down
46 changes: 46 additions & 0 deletions simplexity/generative_processes/transition_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,51 @@ def tom_quantum(alpha: float, beta: float) -> jax.Array:
return transition_matrices


def shuriken(p: float = 0.72, r: float = 0.24, u: float = 0.36, v: float = 0.52) -> jax.Array:
    """Build the symbol-labeled transition matrices of the Shuriken Process.

    The Shuriken Process is a parameterized family of binary, edge-emitting, nonunifilar
    HMMs over three hidden states (A, B, C) and the alphabet {0, 1}.

    Entry T[x, i, j] is P(next_state=j, emit x | current_state=i):

        T0 = [[u, p-u, 0  ],        T1 = [[0,   0,   1-p],
              [0, v,   p-v],              [1-p, 0,   0  ],
              [r, 0,   0  ]]              [0,   1-r, 0  ]]

    Every row of T0 + T1 sums to 1, so the pair forms a valid HMM whenever all entries are
    non-negative.

    Minimality (all three hidden states are required) holds when the determinant
        det(M) = -(p - r)^2 * (p - v)
    is nonzero, i.e. when p != r and p != v; larger gaps also give better numerical
    conditioning. The suggested constraints
        0 < r < p < 1,   0 < u < p,   0 < v < p
    guarantee both minimality and non-negative entries.

    Args:
        p: Probability of emitting 0 from state A and from state B (so 1 - p is the
            probability of emitting 1 from either of them).
        r: Probability of emitting 0 from state C. Must not equal p for minimality.
        u: Probability of emitting 0 and remaining in A, given state A. Sets how the
            emit-0 mass from A is split between A and B.
        v: Probability of emitting 0 and remaining in B, given state B. Must not equal p
            for minimality.

    Returns:
        Transition matrices of shape (2, 3, 3), indexed as [symbol, from_state, to_state].
    """
    # Probability of emitting 1 from A or B; both states share it.
    emit_one_from_ab = 1 - p

    # Rows are the source states A, B, C; columns are the destination states A, B, C.
    on_zero = [
        [u, p - u, 0],
        [0, v, p - v],
        [r, 0, 0],
    ]
    on_one = [
        [0, 0, emit_one_from_ab],
        [emit_one_from_ab, 0, 0],
        [0, 1 - r, 0],
    ]
    return jnp.array([on_zero, on_one])


def zero_one_random(p: float) -> jax.Array:
"""Creates a transition matrix for the Zero One Random (Z1R) Process.

Expand Down Expand Up @@ -375,6 +420,7 @@ def zero_one_random(p: float) -> jax.Array:
"mr_name": mr_name,
"no_consecutive_ones": no_consecutive_ones,
"rrxor": rrxor,
"shuriken": shuriken,
"sns": sns,
"zero_one_random": zero_one_random,
}
Expand Down
15 changes: 15 additions & 0 deletions tests/end_to_end/configs/generative_process/shuriken.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Generative-process config selecting the shuriken process by name.
name: shuriken
instance:
# NOTE(review): `_target_` looks like a Hydra-style instantiation target — confirm against the config loader.
_target_: simplexity.generative_processes.builder.build_hidden_markov_model
process_name: shuriken
# These values match the default arguments of `shuriken()`.
process_params:
p: 0.72
r: 0.24
u: 0.36
v: 0.52
# Interpolated from a `device` value defined outside this file.
device: ${device}

# NOTE(review): `???` presumably marks mandatory values that must be supplied elsewhere — confirm.
base_vocab_size: ???
bos_token: ???
eos_token: null
vocab_size: ???
46 changes: 46 additions & 0 deletions tests/generative_processes/test_transition_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
no_consecutive_ones,
post_quantum,
rrxor,
shuriken,
sns,
tom_quantum,
zero_one_random,
Expand Down Expand Up @@ -204,6 +205,51 @@ def test_rrxor():
assert jnp.allclose(stationary_distribution, jnp.array([2, 1, 1, 1, 1]) / 6)


def test_shuriken():
    """Check the default shuriken matrices: shape (2, 3, 3) and valid as HMM transition matrices."""
    default_matrices = shuriken()
    # Two symbols, three hidden states.
    assert default_matrices.shape == (2, 3, 3)
    validate_hmm_transition_matrices(default_matrices)


def test_shuriken_custom_params():
    """Check that non-default shuriken parameters still give valid (2, 3, 3) transition matrices."""
    custom_params = {"p": 0.8, "r": 0.3, "u": 0.4, "v": 0.5}
    custom_matrices = shuriken(**custom_params)
    assert custom_matrices.shape == (2, 3, 3)
    validate_hmm_transition_matrices(custom_matrices)


def test_shuriken_minimality_determinant():
    """Check the minimality determinant against its closed form.

    Starting from each pure state A=[1,0,0], B=[0,1,0], C=[0,0,1], build
        M = [[1, P(0|A), P(00|A)],
             [1, P(0|B), P(00|B)],
             [1, P(0|C), P(00|C)]]
    and assert |det(M)| = (p-r)^2 * (p-v).

    A nonzero determinant means the three pure-state predictive distributions are linearly
    independent, so the model is minimal (all 3 states are needed).
    """
    p, r, u, v = 0.72, 0.24, 0.36, 0.52
    emit_zero = shuriken(p=p, r=r, u=u, v=v)[0]

    # Each row of the identity is one pure initial state.
    pure_states = jnp.eye(3)

    # Unnormalized state distributions after emitting a single 0; the row sums are P(0 | state).
    after_zero = pure_states @ emit_zero
    p0 = jnp.sum(after_zero, axis=-1)

    # Condition on the first 0, then apply the chain rule: P(00) = P(0) * P(0 | first symbol was 0).
    beliefs_after_zero = after_zero / p0[:, None]
    p00 = p0 * jnp.sum(beliefs_after_zero @ emit_zero, axis=-1)

    # Columns of M are [1, P(0|s), P(00|s)] for s in (A, B, C).
    m = jnp.stack([jnp.ones(3), p0, p00], axis=-1)
    actual_abs_det = jnp.abs(jnp.linalg.det(m))
    expected_abs_det = (p - r) ** 2 * (p - v)

    assert jnp.isclose(actual_abs_det, expected_abs_det, atol=1e-6), (
        f"Minimality determinant |det(M)|={actual_abs_det} != expected {expected_abs_det}"
    )


def test_sns():
"""Test the sns transition matrices."""
transition_matrices = sns(p=0.5, q=0.5)
Expand Down
Loading