Skip to content

Commit 4f567c7

Browse files
committed
Add shuriken process: 3-state binary nonunifilar HMM
Implements the shuriken process as a new generative process with parameters p, r, u, v. Includes minimality determinant test and default Hydra config. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4075750 commit 4f567c7

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ __pycache__/
77
data/
88
multirun/
99
outputs/
10+
references/
1011
config.ini
1112
*temp*
1213
# MLflow tracking data

simplexity/generative_processes/transition_matrices.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,51 @@ def tom_quantum(alpha: float, beta: float) -> jax.Array:
342342
return transition_matrices
343343

344344

345+
def shuriken(p: float = 0.72, r: float = 0.24, u: float = 0.36, v: float = 0.52) -> jax.Array:
346+
"""Creates transition matrices for the Shuriken Process.
347+
348+
A parameterized family of 3-state binary edge-emitting nonunifilar HMMs.
349+
350+
States: A, B, C
351+
Alphabet: {0, 1}
352+
353+
Symbol-labeled transition matrices where T[x, i, j] = P(next_state=j, emit x | current_state=i):
354+
355+
T0 = [[u, p-u, 0 ], T1 = [[0, 0, 1-p],
356+
[0, v, p-v], [1-p, 0, 0 ],
357+
[r, 0, 0 ]] [0, 1-r, 0 ]]
358+
359+
The generator is minimal (all 3 hidden states are needed) when the minimality determinant
360+
det(M) = -(p - r)^2 * (p - v) is nonzero, i.e. when p != r and p != v. Keeping these
361+
differences large also improves numerical conditioning. The suggested parameter constraints
362+
0 < r < p < 1, 0 < u < p, 0 < v < p
363+
ensure minimality and that all matrix entries are non-negative.
364+
365+
Args:
366+
p: P(emit 0) from states A and B. Also 1 - P(emit 1) from those states.
367+
r: P(emit 0) from state C. Must differ from p for minimality.
368+
u: P(emit 0, stay in A | state A). Controls the A/B split when emitting 0 from A.
369+
v: P(emit 0, stay in B | state B). Must differ from p for minimality.
370+
371+
Returns:
372+
Transition matrices of shape (2, 3, 3).
373+
"""
374+
return jnp.array(
375+
[
376+
[
377+
[u, p - u, 0],
378+
[0, v, p - v],
379+
[r, 0, 0],
380+
],
381+
[
382+
[0, 0, 1 - p],
383+
[1 - p, 0, 0],
384+
[0, 1 - r, 0],
385+
],
386+
]
387+
)
388+
389+
345390
def zero_one_random(p: float) -> jax.Array:
346391
"""Creates a transition matrix for the Zero One Random (Z1R) Process.
347392
@@ -375,6 +420,7 @@ def zero_one_random(p: float) -> jax.Array:
375420
"mr_name": mr_name,
376421
"no_consecutive_ones": no_consecutive_ones,
377422
"rrxor": rrxor,
423+
"shuriken": shuriken,
378424
"sns": sns,
379425
"zero_one_random": zero_one_random,
380426
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
name: shuriken
2+
instance:
3+
_target_: simplexity.generative_processes.builder.build_hidden_markov_model
4+
process_name: shuriken
5+
process_params:
6+
p: 0.72
7+
r: 0.24
8+
u: 0.36
9+
v: 0.52
10+
device: ${device}
11+
12+
base_vocab_size: ???
13+
bos_token: ???
14+
eos_token: null
15+
vocab_size: ???

tests/generative_processes/test_transition_matrices.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
no_consecutive_ones,
1818
post_quantum,
1919
rrxor,
20+
shuriken,
2021
sns,
2122
tom_quantum,
2223
zero_one_random,
@@ -204,6 +205,51 @@ def test_rrxor():
204205
assert jnp.allclose(stationary_distribution, jnp.array([2, 1, 1, 1, 1]) / 6)
205206

206207

208+
def test_shuriken():
209+
"""Test the shuriken transition matrices."""
210+
transition_matrices = shuriken()
211+
assert transition_matrices.shape == (2, 3, 3)
212+
validate_hmm_transition_matrices(transition_matrices)
213+
214+
215+
def test_shuriken_custom_params():
216+
"""Test the shuriken transition matrices with custom parameters."""
217+
transition_matrices = shuriken(p=0.8, r=0.3, u=0.4, v=0.5)
218+
assert transition_matrices.shape == (2, 3, 3)
219+
validate_hmm_transition_matrices(transition_matrices)
220+
221+
222+
def test_shuriken_minimality_determinant():
223+
"""Test that the minimality determinant matches the closed-form expression.
224+
225+
For pure states A=[1,0,0], B=[0,1,0], C=[0,0,1], construct
226+
M = [[1, P(0|A), P(00|A)],
227+
[1, P(0|B), P(00|B)],
228+
[1, P(0|C), P(00|C)]]
229+
and verify |det(M)| = (p-r)^2 * (p-v).
230+
231+
The nonzero determinant confirms that the three pure-state predictive distributions
232+
are linearly independent, establishing that the model is minimal (3 states are needed).
233+
"""
234+
p, r, u, v = 0.72, 0.24, 0.36, 0.52
235+
transition_matrices = shuriken(p=p, r=r, u=u, v=v)
236+
t0 = transition_matrices[0]
237+
238+
pure_states = jnp.eye(3)
239+
p0 = jnp.sum(pure_states @ t0, axis=-1)
240+
next_states = pure_states @ t0
241+
next_states_normalized = next_states / p0[:, None]
242+
p00 = p0 * jnp.sum(next_states_normalized @ t0, axis=-1)
243+
244+
m = jnp.stack([jnp.ones(3), p0, p00], axis=-1)
245+
actual_abs_det = jnp.abs(jnp.linalg.det(m))
246+
expected_abs_det = (p - r) ** 2 * (p - v)
247+
248+
assert jnp.isclose(actual_abs_det, expected_abs_det, atol=1e-6), (
249+
f"Minimality determinant |det(M)|={actual_abs_det} != expected {expected_abs_det}"
250+
)
251+
252+
207253
def test_sns():
208254
"""Test the sns transition matrices."""
209255
transition_matrices = sns(p=0.5, q=0.5)

0 commit comments

Comments
 (0)