Skip to content

Commit 4d81128

Browse files
m9hclaude
andcommitted
Bridge HGF, DDM, metacognition, hierarchical to pymdp Agent
All four unique ALF modules now accept pymdp Agents directly: - HGFPerceptualAgent: accepts pymdp Agent (converts to ALF GM internally) - DDM bridge: add neg_efe_to_ddm() for pymdp's sign convention - MetacognitiveAgent: accepts pymdp Agent (converts to ALF GM, wraps) - HierarchicalGenerativeModel.from_pymdp(): creates base level from pymdp Agent with optional higher context levels 9 new tests cover all bridges. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1bcfe9d commit 4d81128

6 files changed

Lines changed: 208 additions & 6 deletions

File tree

alf/ddm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from alf.ddm.bridge import (
3939
efe_to_ddm,
40+
neg_efe_to_ddm,
4041
ddm_to_policy_probs,
4142
)
4243
from alf.ddm.fitting import (
@@ -60,6 +61,7 @@
6061
"simulate_ddm",
6162
# Bridge
6263
"efe_to_ddm",
64+
"neg_efe_to_ddm",
6365
"ddm_to_policy_probs",
6466
# Fitting
6567
"DDMFitResult",

alf/ddm/bridge.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,32 @@ def efe_to_ddm(
5656
return DDMParams(v=v, a=a, w=w, tau=jnp.array(tau))
5757

5858

59+
def neg_efe_to_ddm(
60+
neg_efe: jnp.ndarray,
61+
gamma: float = 4.0,
62+
tau: float = 0.3,
63+
base_a: float = 1.5,
64+
) -> DDMParams:
65+
"""Map pymdp neg_efe values to DDM parameters for a binary choice.
66+
67+
Convenience wrapper for efe_to_ddm that accepts pymdp's neg_efe
68+
convention (higher = better) instead of ALF's G (lower = better).
69+
70+
Args:
71+
neg_efe: Negative EFE for each action, shape (2,) or (1, 2).
72+
Higher values indicate preferred actions.
73+
gamma: Policy precision (inverse temperature).
74+
tau: Non-decision time.
75+
base_a: Base boundary separation.
76+
77+
Returns:
78+
DDMParams with v, a, w, tau.
79+
"""
80+
if neg_efe.ndim == 2:
81+
neg_efe = neg_efe[0]
82+
return efe_to_ddm(-neg_efe, gamma=gamma, tau=tau, base_a=base_a)
83+
84+
5985
def ddm_to_policy_probs(
6086
params: DDMParams,
6187
) -> jnp.ndarray:

alf/hgf/bridge.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ class HGFPerceptualAgent:
115115
116116
Args:
117117
gm: Generative model with B, C matrices for action selection.
118+
Accepts either an ALF GenerativeModel or a pymdp Agent
119+
(which will be converted to ALF GM internally).
118120
hgf_params: HGF parameters (BinaryHGFParams or ContinuousHGFParams).
119121
gamma: Policy precision (inverse temperature). Default 4.0.
120122
state_range: Range for discretizing HGF beliefs.
@@ -123,12 +125,15 @@ class HGFPerceptualAgent:
123125

124126
def __init__(
125127
self,
126-
gm: GenerativeModel,
128+
gm,
127129
hgf_params: BinaryHGFParams | ContinuousHGFParams,
128130
gamma: float = 4.0,
129131
state_range: tuple[float, float] = (-3.0, 3.0),
130132
seed: int = 42,
131133
):
134+
if not isinstance(gm, GenerativeModel):
135+
from alf.compat import pymdp_to_alf
136+
gm = pymdp_to_alf(gm)
132137
self.gm = gm
133138
self.hgf_params = hgf_params
134139
self.gamma = gamma

alf/hierarchical.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,33 @@ class HierarchicalGenerativeModel:
205205
def __init__(self, levels: list[HierarchicalLevel]):
206206
self.levels = list(levels)
207207

208+
@classmethod
209+
def from_pymdp(cls, agent, higher_levels: list[HierarchicalLevel] = None,
210+
level_name: str = "sensorimotor") -> "HierarchicalGenerativeModel":
211+
"""Create a hierarchy with the lowest level from a pymdp Agent.
212+
213+
Extracts the A, B, C, D matrices from a pymdp Agent (stripping the
214+
batch dimension) and uses them as the lowest-level (fastest) model.
215+
Additional higher levels can be appended.
216+
217+
Args:
218+
agent: pymdp.agent.Agent instance.
219+
higher_levels: Optional list of HierarchicalLevel objects for
220+
context/strategy levels above the base sensorimotor level.
221+
level_name: Name for the base level.
222+
223+
Returns:
224+
HierarchicalGenerativeModel with the pymdp model as level 0.
225+
"""
226+
from alf.compat import pymdp_to_alf
227+
gm = pymdp_to_alf(agent)
228+
base_level = HierarchicalLevel(
229+
A=gm.A[0], B=gm.B[0], C=gm.C[0], D=gm.D[0],
230+
temporal_scale=1, level_name=level_name,
231+
)
232+
levels = [base_level] + (higher_levels or [])
233+
return cls(levels)
234+
208235
@property
209236
def num_levels(self) -> int:
210237
return len(self.levels)

alf/metacognition.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -652,10 +652,10 @@ class MetacognitiveAgent:
652652
653653
Args:
654654
inner_agent: The AnalyticAgent to wrap. Can also accept a
655-
GenerativeModel, in which case an AnalyticAgent is created
656-
internally.
655+
GenerativeModel (creates an AnalyticAgent internally) or
656+
a pymdp Agent (converted to ALF GM, then wrapped).
657657
gamma: Initial policy precision. Only used if inner_agent is a
658-
GenerativeModel. Default 4.0.
658+
GenerativeModel or pymdp Agent. Default 4.0.
659659
monitor_decay: EMA decay for the EFEMonitor. Default 0.95.
660660
monitor_window: Window size for the EFEMonitor. Default 50.
661661
gamma_learning_rate: Learning rate for gamma adjustment based
@@ -670,7 +670,7 @@ class MetacognitiveAgent:
670670

671671
def __init__(
672672
self,
673-
inner_agent: AnalyticAgent | GenerativeModel,
673+
inner_agent,
674674
gamma: float = 4.0,
675675
monitor_decay: float = 0.95,
676676
monitor_window: int = 50,
@@ -682,8 +682,13 @@ def __init__(
682682
):
683683
if isinstance(inner_agent, GenerativeModel):
684684
self.agent = AnalyticAgent(inner_agent, gamma=gamma, seed=seed)
685-
else:
685+
elif isinstance(inner_agent, AnalyticAgent):
686686
self.agent = inner_agent
687+
else:
688+
# Assume pymdp Agent — convert to ALF
689+
from alf.compat import pymdp_to_alf
690+
gm = pymdp_to_alf(inner_agent)
691+
self.agent = AnalyticAgent(gm, gamma=gamma, seed=seed)
687692

688693
self.monitor = EFEMonitor(
689694
decay=monitor_decay,

alf/tests/test_pymdp_bridges.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Tests for pymdp bridge integration in HGF, DDM, metacognition, hierarchical."""
2+
3+
import numpy as np
4+
import jax.numpy as jnp
5+
import pytest
6+
7+
try:
8+
from pymdp.agent import Agent as PyMDP_Agent
9+
from alf.compat import alf_to_pymdp
10+
HAS_PYMDP = True
11+
except ImportError:
12+
HAS_PYMDP = False
13+
14+
from alf.generative_model import GenerativeModel
15+
16+
pytestmark = pytest.mark.skipif(not HAS_PYMDP, reason="pymdp not installed")
17+
18+
19+
@pytest.fixture
20+
def simple_model():
21+
"""3-state, 3-obs, 2-action model."""
22+
A = np.array([[0.9, 0.05, 0.05],
23+
[0.05, 0.9, 0.05],
24+
[0.05, 0.05, 0.9]])
25+
B = np.zeros((3, 3, 2))
26+
B[:, :, 0] = np.array([[0.8, 0.1, 0.1],
27+
[0.1, 0.8, 0.1],
28+
[0.1, 0.1, 0.8]])
29+
B[:, :, 1] = np.array([[0.1, 0.1, 0.8],
30+
[0.8, 0.1, 0.1],
31+
[0.1, 0.8, 0.1]])
32+
C = np.array([2.0, 0.0, -1.0])
33+
D = np.array([1/3, 1/3, 1/3])
34+
return GenerativeModel(A=[A], B=[B], C=[C], D=[D])
35+
36+
37+
@pytest.fixture
38+
def pymdp_agent(simple_model):
39+
return alf_to_pymdp(simple_model)
40+
41+
42+
class TestHGFBridgePymdp:
43+
def test_accepts_pymdp_agent(self, pymdp_agent):
44+
from alf.hgf.bridge import HGFPerceptualAgent
45+
from alf.hgf.updates import BinaryHGFParams
46+
47+
params = BinaryHGFParams(omega_2=-2.0, mu_2_0=0.0, sigma_2_0=1.0)
48+
agent = HGFPerceptualAgent(pymdp_agent, params, gamma=4.0)
49+
assert agent.gm is not None
50+
assert isinstance(agent.gm, GenerativeModel)
51+
52+
def test_step_works_with_pymdp(self, pymdp_agent):
53+
from alf.hgf.bridge import HGFPerceptualAgent
54+
from alf.hgf.updates import BinaryHGFParams
55+
56+
params = BinaryHGFParams(omega_2=-2.0, mu_2_0=0.0, sigma_2_0=1.0)
57+
agent = HGFPerceptualAgent(pymdp_agent, params, gamma=4.0)
58+
action, info = agent.step(1.0)
59+
assert isinstance(action, int)
60+
assert "G" in info
61+
62+
63+
class TestDDMBridgePymdp:
64+
def test_neg_efe_to_ddm(self):
65+
from alf.ddm.bridge import neg_efe_to_ddm, efe_to_ddm
66+
67+
G = jnp.array([-1.5, -0.5])
68+
neg_efe = jnp.array([1.5, 0.5])
69+
70+
params_from_G = efe_to_ddm(G)
71+
params_from_neg_efe = neg_efe_to_ddm(neg_efe)
72+
73+
np.testing.assert_allclose(float(params_from_G.v),
74+
float(params_from_neg_efe.v), atol=1e-5)
75+
np.testing.assert_allclose(float(params_from_G.a),
76+
float(params_from_neg_efe.a), atol=1e-5)
77+
78+
def test_neg_efe_to_ddm_with_batch_dim(self):
79+
from alf.ddm.bridge import neg_efe_to_ddm
80+
81+
neg_efe_batched = jnp.array([[1.5, 0.5]])
82+
params = neg_efe_to_ddm(neg_efe_batched)
83+
assert jnp.isfinite(params.v)
84+
85+
86+
class TestMetacognitionPymdp:
87+
def test_accepts_pymdp_agent(self, pymdp_agent):
88+
from alf.metacognition import MetacognitiveAgent
89+
90+
agent = MetacognitiveAgent(pymdp_agent, gamma=4.0)
91+
assert agent.gm is not None
92+
93+
def test_step_works_with_pymdp(self, pymdp_agent):
94+
from alf.metacognition import MetacognitiveAgent
95+
96+
agent = MetacognitiveAgent(pymdp_agent, gamma=4.0)
97+
action, info = agent.step([0])
98+
assert isinstance(action, int)
99+
assert "metacognitive_confidence" in info
100+
assert 0.0 <= info["metacognitive_confidence"] <= 1.0
101+
102+
def test_learn_works_with_pymdp(self, pymdp_agent):
103+
from alf.metacognition import MetacognitiveAgent
104+
105+
agent = MetacognitiveAgent(pymdp_agent, gamma=4.0)
106+
agent.step([0])
107+
agent.learn(1.0)
108+
assert len(agent.accuracy_history) == 1
109+
110+
111+
class TestHierarchicalPymdp:
112+
def test_from_pymdp(self, pymdp_agent):
113+
from alf.hierarchical import HierarchicalGenerativeModel
114+
115+
hierarchy = HierarchicalGenerativeModel.from_pymdp(pymdp_agent)
116+
assert hierarchy.num_levels == 1
117+
assert hierarchy.levels[0].num_states == 3
118+
assert hierarchy.levels[0].num_actions == 2
119+
120+
def test_from_pymdp_with_higher_levels(self, pymdp_agent):
121+
from alf.hierarchical import HierarchicalGenerativeModel, HierarchicalLevel
122+
123+
context_A = np.eye(2)
124+
context_B = np.stack([np.eye(2)] * 2, axis=-1)
125+
context_C = np.array([1.0, 0.0])
126+
context_D = np.array([0.5, 0.5])
127+
context_level = HierarchicalLevel(
128+
A=context_A, B=context_B, C=context_C, D=context_D,
129+
temporal_scale=5, level_name="context",
130+
)
131+
132+
hierarchy = HierarchicalGenerativeModel.from_pymdp(
133+
pymdp_agent, higher_levels=[context_level]
134+
)
135+
assert hierarchy.num_levels == 2
136+
assert hierarchy.levels[0].level_name == "sensorimotor"
137+
assert hierarchy.levels[1].level_name == "context"

0 commit comments

Comments
 (0)