GNN Integration Layer: Python / JAX-based Probabilistic Programming Framework Base:
numpyro >= 0.14(NumPy-interface Pyro with JAX backend) Simulation Architecture: Inference-as-sampling POMDP agent Documentation Version: 2.0.0
NumPyro provides a probabilistic programming backend for GNN models. Unlike PyMDP (which uses fixed-point belief updates) or JAX (which uses manual message-passing), NumPyro treats the generative model as a probabilistic program and uses MCMC or SVI (Stochastic Variational Inference) for posterior inference.
This unlocks uncertainty quantification beyond the classical Dirichlet-categorical parameterisation — GNN-specified matrices become priors over distributions, not point estimates.
| Stage | Module | Description |
|---|---|---|
| Rendering (Step 11) | src/render/numpyro/numpyro_renderer.py |
GNN JSON → NumPyro probabilistic program |
| Execution (Step 12) | src/execute/numpyro/numpyro_runner.py |
MCMC/SVI inference, log persistence |
| Analysis (Step 16) | src/analysis/numpyro/analyzer.py |
Posterior summaries, uncertainty bands |
import numpyro
import numpyro.distributions as dist
from jax import random
def gnn_generative_model(A, B, C, D, T=20):
"""NumPyro probabilistic program for GNN POMDP."""
# Prior over initial hidden state
initial_state_dist = dist.Categorical(probs=D)
true_state = numpyro.sample("s_0", initial_state_dist)
observations = []
for t in range(T):
# Observation model: P(o | s)
obs = numpyro.sample(f"o_{t}", dist.Categorical(probs=A[:, true_state]))
observations.append(obs)
if t < T - 1:
# Expected Free Energy–based action selection
# (simplified: uniform prior over actions for unconditioned sampling)
action = numpyro.sample(f"a_{t}", dist.Categorical(probs=C / C.sum()))
# Transition model: P(s' | s, a)
true_state = numpyro.sample(
f"s_{t+1}",
dist.Categorical(probs=B[:, true_state, action])
)
return observationsGNN matrices become NumPyro concentration parameters:
import jax.numpy as jnp
# A matrix: Dirichlet prior over likelihood columns
A_prior = A + 1e-6 # stabilize zeros
A_dist = dist.Dirichlet(concentration=A_prior.T * 10.0)
A_sample = numpyro.sample("A", A_dist)
# D vector: Dirichlet prior over initial state
D_dist = dist.Dirichlet(concentration=D * 10.0)
D_sample = numpyro.sample("D", D_dist)from numpyro.infer import MCMC, NUTS
# NUTS (No-U-Turn Sampler)
nuts_kernel = NUTS(gnn_generative_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), A=A_matrix, B=B_matrix, C=C_vector, D=D_vector)
# Posterior samples
posterior = mcmc.get_samples()
# {"A": shape [1000, num_obs, num_states], "D": shape [1000, num_states], ...}{
"inference": {
"method": "NUTS",
"num_samples": 1000,
"num_warmup": 500,
"divergences": 0,
"r_hat_max": 1.01
},
"simulation_trace": {
"observations": [1, 0, 2, 1, 1],
"beliefs": [[0.05, 0.90, 0.05], ...],
"actions": [0, 0, 1, 0, 0],
"efe_history": [[-1.2, -0.8, -2.1], ...],
"posterior_uncertainty": {
"A_std": 0.03,
"D_std": 0.08
}
}
}pip install numpyro jax jaxlib
# GPU (CUDA):
pip install numpyro "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.htmlCheck availability:
PYTHONPATH=src python -c "from execute.numpyro import check_numpyro; print(check_numpyro())"# Render GNN to NumPyro script
python src/11_render.py --target-dir input/gnn_files/ --framework numpyro
# Execute with MCMC inference
python src/12_execute.py --target-dir input/gnn_files/ --framework numpyro
# Or use MCP tool
# Call: execute_gnn_model(path="...", framework="numpyro")| Feature | PyMDP | JAX | NumPyro |
|---|---|---|---|
| Inference method | Belief propagation | Manual message-passing | MCMC / SVI |
| Uncertainty output | None | None | Full posterior |
| GPU support | No | Yes | Yes |
| Speed (20 steps) | Fast | Fast | Slow (sampling) |
| Use case | Classical POMDP | High-perf POMDP | Uncertainty quantification |
During the March 6, 2026 pipeline benchmarking audit, NumPyro was verified as Fully Operational. NumPyro's posterior mean beliefs correlate with PyMDP and JAX beliefs at ~1.0 for deterministic GNN matrices. Posterior variance is the unique contribution, supplying rich uncertainty mechanics while maintaining cross-framework fidelity.
| Stage | Module | Key Function |
|---|---|---|
| Rendering | numpyro_renderer.py | render_gnn_to_numpyro() |
| Execution | numpyro_runner.py | execute_numpyro_script() |
| Analysis | analyzer.py | generate_analysis_from_logs() |
| ID | Area | Description | Impact |
|---|---|---|---|
| NP-1 | Rendering | ✅ FIXED | |
| NP-2 | Execution | SVI not yet implemented (only NUTS) | Medium |
| NP-3 | Analysis | Posterior uncertainty plots not linked to Step 9 (Adv Viz) | Low |
- Cross-Framework Methodology: Details on the correlation methodology and benchmarking metrics.
- Architecture Reference: Deep dive into the pipeline orchestrator and module integration.
- GNN Implementations Index: Return to the master framework implementer manifest.
- Back to GNN START_HERE