FabricPC's built-in nodes cover many common use cases, but you may need custom nodes for specialized computations. This guide walks through creating your own node types.
Built-in nodes include:
- Linear: Fully-connected layers
- IdentityNode: Passthrough nodes
- StorkeyHopfield: Associative memory
- TransformerBlock: Multi-head attention and feedforward
Create a custom node when you need:
- Gating mechanisms (LSTM, GRU)
- Custom transfer functions
- Specialized domain-specific projections
- Any computation not covered by built-in nodes
All nodes inherit from NodeBase and must implement three static methods:
- get_slots(): Define input slots (the connection interface)
- initialize_params(): Allocate and initialize weights and biases
- forward(): Compute predictions, errors, and energy
These methods are static because FabricPC uses a functional JAX-based design. Node instances hold configuration; the static methods define pure functional transformations.
forward() is where the node does its real work, and its body is intentionally unconstrained: how you turn inputs into a prediction is up to you. But a fixed set of steps must happen inside it for the node to participate in inference and learning — these are spelled out in Implement Forward Computation below. The split is:
- Required (every node): produce z_mu, record pre_activation, compute error, write those fields back, populate energy via the energy functional, and return (total_energy, state).
- Flexible (per node): how inputs are combined (sum, matmul, attention, embedding lookup), whether weights/biases exist, whether and which activation applies, any internal sub-structure (LayerNorm, attention, residual paths), and any extra energy terms.
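As a framework-free illustration of that split, the sketch below performs the required steps around a deliberately simple flexible part (a matmul followed by tanh). ToyState and toy_forward are hypothetical stand-ins, not FabricPC code: ToyState copies NodeState's six fields, and the per-sample energy assumes a Gaussian form of 0.5 * ||z_latent - z_mu||^2.

```python
from typing import NamedTuple

import numpy as np


class ToyState(NamedTuple):
    # Hypothetical stand-in with the same six fields as FabricPC's NodeState
    z_latent: np.ndarray
    z_mu: np.ndarray
    error: np.ndarray
    energy: np.ndarray
    pre_activation: np.ndarray
    latent_grad: np.ndarray


def toy_forward(weight, x, state):
    # Flexible part: how inputs become a prediction (here a matmul + tanh)
    pre_activation = x @ weight
    z_mu = np.tanh(pre_activation)
    # Required: error is latent minus prediction
    error = state.z_latent - z_mu
    # Required: write the fields back (builds a new tuple, no mutation)
    state = state._replace(pre_activation=pre_activation, z_mu=z_mu, error=error)
    # Required: per-sample energy, shape (batch,); assumed Gaussian form
    energy = 0.5 * np.sum(error ** 2, axis=tuple(range(1, error.ndim)))
    state = state._replace(energy=energy)
    # Required: scalar total energy first, updated state second
    return np.sum(state.energy), state


rng = np.random.default_rng(0)
batch, d_in, d_out = 4, 3, 2
zeros = np.zeros((batch, d_out))
state = ToyState(
    z_latent=rng.normal(size=(batch, d_out)),
    z_mu=zeros,
    error=zeros,
    energy=np.zeros(batch),
    pre_activation=zeros,
    latent_grad=zeros,
)
total, new_state = toy_forward(rng.normal(size=(d_in, d_out)), rng.normal(size=(batch, d_in)), state)
```

Because the function only builds new tuples with _replace, the caller's original state is left untouched, which is the purity property forward() relies on.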
Let's build a 2D convolutional node from scratch, to illustrate the node contract. (FabricPC now ships a production ConvNode in fabricpc.nodes; this from-scratch version is for teaching, not for use.)
import jax
import jax.numpy as jnp
import numpy as np
from fabricpc.nodes.base import NodeBase, SlotSpec
from fabricpc.core.types import NodeParams, NodeState, NodeInfo
from fabricpc.core.activations import ReLUActivation
from fabricpc.core.energy import GaussianEnergy
from fabricpc.core.initializers import NormalInitializer, initialize

class Conv2DNode(NodeBase):
    def __init__(
        self,
        shape,
        name,
        kernel_size,
        stride=(1, 1),
        padding="SAME",
        activation=ReLUActivation(),
        energy=GaussianEnergy(),
        latent_init=NormalInitializer(),
        weight_init=NormalInitializer(),
        **kwargs
    ):
        # Custom options (kernel_size, stride, padding) are passed as keyword
        # arguments and end up in node_info.node_config
        super().__init__(
            shape=shape,
            name=name,
            activation=activation,
            energy=energy,
            latent_init=latent_init,
            weight_init=weight_init,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            **kwargs
        )

Key points:
- Accept standard node parameters (shape, name, activation, etc.)
- Accept custom parameters (kernel_size, stride, padding)
- Pass all of them to super().__init__() as keyword arguments, forwarding any extra options via **kwargs
- Custom parameters end up in node_info.node_config and are accessible in static methods
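Since the caller declares shape explicitly, its spatial part has to agree with kernel_size, stride, and padding. The helper below is hypothetical (not a FabricPC function); it applies the standard "VALID" and "SAME" rules that jax.lax.conv_general_dilated uses, while the channel count is simply the C_out you choose.

```python
import math


def conv_output_hw(in_hw, kernel_size, stride=(1, 1), padding="SAME"):
    """Spatial output size of a 2D convolution under VALID/SAME padding."""
    out = []
    for n, k, s in zip(in_hw, kernel_size, stride):
        if padding == "VALID":
            # Only positions where the whole kernel fits inside the input
            out.append((n - k) // s + 1)
        elif padding == "SAME":
            # Input is padded so the output depends only on the stride
            out.append(math.ceil(n / s))
        else:
            raise ValueError(f"unsupported padding: {padding!r}")
    return tuple(out)


print(conv_output_hw((28, 28), (3, 3), padding="VALID"))  # (26, 26)
print(conv_output_hw((26, 26), (3, 3), padding="VALID"))  # (24, 24)
print(conv_output_hw((28, 28), (3, 3), padding="SAME"))   # (28, 28)
```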
    @staticmethod
    def get_slots():
        return {"in": SlotSpec(name="in", is_multi_input=True)}

This defines a single input slot named "in" that accepts multiple incoming edges. The is_multi_input=True flag means contributions from different source nodes will be summed.
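For a conv or dense node, "summed" means each edge is first transformed with its own weights and the per-edge results are then added into one pre-activation, which is what Conv2DNode.forward() does with its edge loop. The numpy sketch below shows the same pattern with matrix multiplies; the edge keys, shapes, and the aggregate helper are made up for illustration.

```python
import numpy as np

# Hypothetical inputs for one multi-input slot; keys use the "source->target:slot" form
inputs = {"a->c:in": np.ones((2, 3)), "b->c:in": np.full((2, 4), 2.0)}
# One weight matrix per edge, so the sources may have different widths
weights = {"a->c:in": np.ones((3, 5)), "b->c:in": np.ones((4, 5))}


def aggregate(inputs, weights):
    # Transform each edge with its own weights, then sum into one pre-activation
    total = None
    for edge_key, x in inputs.items():
        out = x @ weights[edge_key]
        total = out if total is None else total + out
    return total


pre_activation = aggregate(inputs, weights)  # every entry is 3 + 8 = 11
```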
For nodes with multiple distinct inputs:
    @staticmethod
    def get_slots():
        return {
            "in": SlotSpec(name="in", is_multi_input=True),
            "mask": SlotSpec(name="mask", is_multi_input=False),
        }

For muPC scaling, override get_weight_fan_in() to return the correct fan-in:
    @staticmethod
    def get_weight_fan_in(source_shape, config):
        kernel_size = config.get("kernel_size", (1, 1))
        C_in = source_shape[-1]  # channels-last format
        return C_in * int(np.prod(kernel_size))

For a 3x3 kernel with 16 input channels: fan_in = 16 * 3 * 3 = 144.
If you don't override this method, the default implementation uses the flattened source shape, which works for fully-connected layers but not for convolutions.
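To see how far apart the two counts are, the sketch below compares the convolutional fan-in with a flattened-shape fan-in for the 26x26x16 output of conv1 used later in this guide. flattened_fan_in is an assumption about what "the flattened source shape" means (the product of the unbatched source dimensions), not FabricPC's code.

```python
import numpy as np


def conv_fan_in(source_shape, kernel_size):
    # Mirrors get_weight_fan_in() above: C_in * kH * kW (channels-last)
    return source_shape[-1] * int(np.prod(kernel_size))


def flattened_fan_in(source_shape):
    # Assumed default: every element of the unbatched source counts as an input
    return int(np.prod(source_shape))


source = (26, 26, 16)
conv = conv_fan_in(source, (3, 3))  # 144
flat = flattened_fan_in(source)     # 10816
```

The flattened count is about 75 times larger; under a 1/sqrt(fan_in) factor (used here only as an example), the weights would be scaled down by an extra factor of 104/12, roughly 8.7.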
    @staticmethod
    def initialize_params(key, node_shape, input_shapes, weight_init=None, config=None):
        """
        Initialize convolutional kernels and biases.

        Args:
            key: JAX random key
            node_shape: Output shape of this node (H, W, C_out)
            input_shapes: Dict mapping edge_key -> source_shape
            weight_init: Weight initializer
            config: Node configuration dict

        Returns:
            NodeParams(weights=weights_dict, biases=biases_dict)
        """
        config = config or {}  # guard against the config=None default
        kernel_size = config.get("kernel_size", (1, 1))
        C_out = node_shape[-1]  # output channels

        weights = {}
        biases = {}
        for edge_key, source_shape in input_shapes.items():
            C_in = source_shape[-1]  # input channels
            # Kernel shape: (kH, kW, C_in, C_out)
            kernel_shape = (*kernel_size, C_in, C_out)

            # Initialize kernel weights
            subkey, key = jax.random.split(key)
            fan_in = C_in * int(np.prod(kernel_size))
            fan_out = C_out * int(np.prod(kernel_size))
            weights[edge_key] = initialize(
                subkey,
                kernel_shape,
                weight_init,
                fan_in=fan_in,
                fan_out=fan_out,
            )

            # Bias shape: (1, 1, 1, C_out) for broadcasting
            biases[edge_key] = jnp.zeros((1, 1, 1, C_out))

        return NodeParams(weights=weights, biases=biases)

Key points:
- input_shapes is a dict keyed by edge identifiers (e.g., "conv1->conv2:in")
- Each edge gets its own kernel and bias
- Use the initialize() helper function with the provided weight_init initializer
- Return NodeParams with dicts of weights and biases
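A numpy stand-in makes the per-edge layout easy to inspect. init_conv_params, the second edge "aux->conv2:in", and the 1/sqrt(fan_in) standard deviation are all illustrative assumptions; FabricPC's initialize() and NormalInitializer decide the actual distribution.

```python
import numpy as np


def init_conv_params(rng, node_shape, input_shapes, kernel_size):
    """Hypothetical numpy stand-in for Conv2DNode.initialize_params()."""
    c_out = node_shape[-1]
    weights, biases = {}, {}
    for edge_key, source_shape in input_shapes.items():
        c_in = source_shape[-1]
        fan_in = c_in * int(np.prod(kernel_size))
        # Illustrative choice: std = 1/sqrt(fan_in); FabricPC's initializers may differ
        weights[edge_key] = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(*kernel_size, c_in, c_out))
        # One broadcastable bias per edge, as in the method above
        biases[edge_key] = np.zeros((1, 1, 1, c_out))
    return weights, biases


rng = np.random.default_rng(0)
w, b = init_conv_params(
    rng,
    node_shape=(24, 24, 32),
    input_shapes={"conv1->conv2:in": (26, 26, 16), "aux->conv2:in": (26, 26, 8)},
    kernel_size=(3, 3),
)
```

Each edge's kernel takes its C_in from its own source, so sources with different channel counts can feed the same multi-input slot as long as their spatial sizes agree.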
    @staticmethod
    def forward(params, inputs, state, node_info):
        """
        Forward pass: compute convolution, activation, error, and energy.

        Args:
            params: NodeParams with weights and biases dicts
            inputs: Dict mapping edge_key -> input_array
            state: Current NodeState
            node_info: NodeInfo with configuration

        Returns:
            (total_energy, updated_state)
        """
        # Extract config
        kernel_size = node_info.node_config.get("kernel_size", (1, 1))
        stride = node_info.node_config.get("stride", (1, 1))
        padding = node_info.node_config.get("padding", "SAME")
        activation = node_info.activation

        # Accumulate convolution outputs from all input edges
        pre_activation = None
        for edge_key, input_array in inputs.items():
            kernel = params.weights[edge_key]
            bias = params.biases[edge_key]

            # Perform convolution
            # JAX uses (batch, H, W, C) format
            conv_out = jax.lax.conv_general_dilated(
                lhs=input_array,
                rhs=kernel,
                window_strides=stride,
                padding=padding,
                dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
            )
            conv_out = conv_out + bias

            if pre_activation is None:
                pre_activation = conv_out
            else:
                pre_activation = pre_activation + conv_out

        # Apply activation function
        z_mu = type(activation).forward(pre_activation, activation.config)

        # Compute prediction error
        error = state.z_latent - z_mu

        # Update state
        state = state._replace(
            pre_activation=pre_activation,
            z_mu=z_mu,
            error=error,
        )

        # Compute energy using the energy functional
        node_class = node_info.node_class
        state = node_class.energy_functional(state, node_info)

        # Return total energy and updated state
        total_energy = jnp.sum(state.energy)
        return total_energy, state

forward() is a pure function. It must have no side effects and must express its dependence on params, inputs, and state.z_latent entirely through JAX operations, because the framework differentiates it under jax.value_and_grad: forward_and_latent_grads differentiates it with respect to inputs and z_latent, and forward_and_weight_grads with respect to params. Side effects are invisible to JAX tracing, and Python-level control flow that branches on traced values fails under jax.jit; either can break inference and learning.
Within that constraint, every forward() must perform these six steps in order:
- Predict z_mu: produce the node's prediction of its own latent, with shape (batch,) + node_info.shape. How is up to the node: a convolution here, a matmul in Linear, an attention pipeline in TransformerBlock.
- Record pre_activation: the value before the activation function. If the node applies no activation, set pre_activation = z_mu. pre_activation is planned for deprecation as a persistent attribute of NodeState; it is really an ephemeral intermediate on the way to z_mu.
- Compute the error: error = state.z_latent - z_mu. The energy functionals assume this sign (latent minus prediction).
- Write the fields back: state = state._replace(z_mu=..., pre_activation=..., error=...). NodeState is a fixed-schema NamedTuple (z_latent, z_mu, error, energy, pre_activation, latent_grad); no other fields exist or may be added.
- Populate energy: node_class = node_info.node_class; state = node_class.energy_functional(state, node_info). This sets state.energy from energy(z_latent, z_mu), so z_mu must already be set. Extra energy terms (for example the Hopfield attractor term in StorkeyHopfield) are added by replacing state.energy after this call.
- Return: return jnp.sum(state.energy), state, with the scalar total energy first and the updated state second.
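The sign convention in the error step matters because inference descends the energy gradient with respect to z_latent. Assuming a Gaussian energy of 0.5 * sum((z_latent - z_mu)^2) per sample (FabricPC's GaussianEnergy may add constants or a variance term), that gradient is exactly +error, so a descent step moves z_latent toward z_mu. The finite-difference check below confirms this numerically.

```python
import numpy as np


def gaussian_energy(z_latent, z_mu):
    # Assumed Gaussian energy: 0.5 * squared error, summed per sample -> shape (batch,)
    return 0.5 * np.sum((z_latent - z_mu) ** 2, axis=-1)


rng = np.random.default_rng(0)
z_latent = rng.normal(size=(2, 3))
z_mu = rng.normal(size=(2, 3))
error = z_latent - z_mu  # latent minus prediction

# Central finite differences of the total energy w.r.t. each latent entry
eps = 1e-6
numeric = np.zeros_like(z_latent)
for idx in np.ndindex(z_latent.shape):
    bump = np.zeros_like(z_latent)
    bump[idx] = eps
    numeric[idx] = (
        gaussian_energy(z_latent + bump, z_mu).sum()
        - gaussian_energy(z_latent - bump, z_mu).sum()
    ) / (2 * eps)
```

With the opposite sign (z_mu - z_latent), the same descent step would push the latent away from its prediction.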
The steps between predicting z_mu and writing it back are free: input aggregation, weights and biases, the choice of activation, and any internal sub-structure are all node-specific.
Do not apply muPC scaling inside forward(). The inference and learning callsites already scale inputs and gradients, so scaling again here double-scales them. See the Initialization and Scaling guide.
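Double-scaling compounds the factor rather than adding to it. With an illustrative 1/sqrt(fan_in) factor (muPC's actual per-layer factors are described in the scaling guide, not here) and the fan-in of 144 from the 3x3x16 example, scaling once gives 1/12 and scaling again inside forward() gives 1/144.

```python
import numpy as np

fan_in = 144                   # 3x3 kernel over 16 input channels
scale = 1.0 / np.sqrt(fan_in)  # illustrative factor only, not muPC's exact rule

x = np.ones(4)
scaled_once = scale * x             # what the callsite already provides
scaled_twice = scale * scaled_once  # what happens if forward() scales again

effective = scaled_twice[0] / x[0]  # 1/144 instead of 1/12
```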
from fabricpc.core.topology import Edge
from fabricpc.graph_assembly import graph, TaskMap
from fabricpc.core.inference import InferenceSGD
# IdentityNode, Linear, and MuPCConfig are FabricPC classes; import them from
# their modules in your FabricPC version (paths are not shown in this guide).

# Create nodes
input_node = IdentityNode(shape=(28, 28, 1), name="input")

conv1 = Conv2DNode(
    shape=(26, 26, 16),  # VALID padding: 28 - 3 + 1 = 26
    kernel_size=(3, 3),
    stride=(1, 1),
    padding="VALID",
    name="conv1",
)

conv2 = Conv2DNode(
    shape=(24, 24, 32),  # VALID padding: 26 - 3 + 1 = 24
    kernel_size=(3, 3),
    stride=(1, 1),
    padding="VALID",
    name="conv2",
)

output_node = Linear(
    shape=(10,),
    flatten_input=True,
    name="output",
)

# Build graph
structure = graph(
    nodes=[input_node, conv1, conv2, output_node],
    edges=[
        Edge(source=input_node, target=conv1.slot("in")),
        Edge(source=conv1, target=conv2.slot("in")),
        Edge(source=conv2, target=output_node.slot("in")),
    ],
    task_map=TaskMap(x=input_node, y=output_node),
    inference=InferenceSGD(eta_infer=0.05, infer_steps=20),
    scaling=MuPCConfig(),
)

For nodes that need dense (fully-connected) behavior with flattened inputs, use the FlattenInputMixin:
from fabricpc.nodes.base import FlattenInputMixin

class MyDenseNode(FlattenInputMixin, NodeBase):
    @staticmethod
    def forward(params, inputs, state, node_info):
        batch_size = state.z_latent.shape[0]
        out_shape = node_info.shape

        # Sum (flattened_input @ weight) over all input edges -> (batch, *out_shape)
        pre_activation = MyDenseNode.compute_linear(
            inputs, params.weights, batch_size, out_shape
        )

        # Add bias, if the node has one
        if "b" in params.biases and params.biases["b"].size > 0:
            pre_activation = pre_activation + params.biases["b"]

        # Apply activation
        z_mu = type(node_info.activation).forward(
            pre_activation, node_info.activation.config
        )

        # ... then compute error, update state, populate energy, and return
        # (the six required steps above)

The mixin provides:
- flatten_input(x): Flattens one input tensor from (batch, *shape) to (batch, numel)
- reshape_output(x_flat, out_shape): Reshapes (batch, numel) back to (batch, *out_shape)
- compute_linear(inputs, weights, batch_size, out_shape): Sums flattened_input @ weight across all input edges and reshapes to (batch, *out_shape). It does not add a bias; add it yourself, as shown above.
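The semantics of the three helpers can be reproduced in a few lines of numpy. This is a hypothetical re-implementation for understanding, not the mixin's source; in particular, it assumes each per-edge weight is stored as a (numel_in, numel_out) matrix, which FabricPC may lay out differently.

```python
import numpy as np


def flatten_input(x):
    # (batch, *shape) -> (batch, numel)
    return x.reshape(x.shape[0], -1)


def reshape_output(x_flat, out_shape):
    # (batch, numel) -> (batch, *out_shape)
    return x_flat.reshape(x_flat.shape[0], *out_shape)


def compute_linear(inputs, weights, batch_size, out_shape):
    # Sum flatten(x) @ W over all edges, then reshape; no bias is added here
    total = np.zeros((batch_size, int(np.prod(out_shape))))
    for edge_key, x in inputs.items():
        total = total + flatten_input(x) @ weights[edge_key]
    return reshape_output(total, out_shape)


inputs = {"a->out:in": np.ones((2, 4, 4, 3)), "b->out:in": np.ones((2, 5))}
weights = {"a->out:in": np.ones((48, 10)), "b->out:in": np.ones((5, 10))}
out = compute_linear(inputs, weights, batch_size=2, out_shape=(10,))  # 48 + 5 = 53 per entry
```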
By default, FabricPC computes gradients with JAX autodiff: it differentiates your forward() to obtain both the latent gradients (inference) and the weight gradients (learning). For hand-coded gradients (e.g. for efficiency or control), override forward_and_latent_grads() and forward_and_weight_grads(). These return gradients, not energy, so their signatures differ from forward():
class MyNode(NodeBase):
    @staticmethod
    def forward_and_latent_grads(params, inputs, state, node_info, is_clamped):
        # Run forward() for the updated state, then compute gradients analytically.
        node_class = node_info.node_class
        _, state = node_class.forward(params, inputs, state, node_info)

        # Self-latent gradient dE/dz_latent via the energy functional
        energy_obj = node_info.energy
        self_grad = type(energy_obj).grad_latent(
            state.z_latent, state.z_mu, energy_obj.config
        )

        # Per-edge input gradients dE/d_input (uses activation.derivative())
        input_grads = {...}

        # Returns (updated_state, input_grads, self_grad)
        return state, input_grads, self_grad

    @staticmethod
    def forward_and_weight_grads(params, inputs, state, node_info):
        node_class = node_info.node_class
        _, state = node_class.forward(params, inputs, state, node_info)
        # Compute weight/bias gradients analytically ...
        # Returns (updated_state, NodeParams(weights=..., biases=...))
        return state, NodeParams(weights=weight_grads, biases=bias_grads)

Note that muPC scaling and accumulation into state.latent_grad are handled by the callsite, not inside these overrides. See LinearExplicitGrad (fabricpc/nodes/linear_explicit_grad.py) for a complete example, including its compute_gain_mod_error() helper that combines state.error with activation.derivative().
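What such overrides compute can be prototyped outside FabricPC. The sketch below hand-derives all four gradients for a dense tanh node, assuming a Gaussian energy E = 0.5 * sum((z - tanh(x @ W + b))^2), and checks them against central finite differences. The gain_mod term (error times the activation derivative) follows the same idea as compute_gain_mod_error(), but the code is illustrative and not taken from LinearExplicitGrad.

```python
import numpy as np


def energy(z, x, W, b):
    # Assumed Gaussian energy of a dense tanh node, summed over the batch
    return 0.5 * np.sum((z - np.tanh(x @ W + b)) ** 2)


def analytic_grads(z, x, W, b):
    a = x @ W + b                                # pre_activation
    error = z - np.tanh(a)                       # latent minus prediction
    gain_mod = error * (1.0 - np.tanh(a) ** 2)   # error * activation derivative
    return {
        "z": error,                  # dE/dz_latent (self gradient)
        "x": -gain_mod @ W.T,        # dE/d_input (per-edge input gradient)
        "W": -x.T @ gain_mod,        # dE/dW
        "b": -gain_mod.sum(axis=0),  # dE/db
    }


def numeric_grad(f, arr, eps=1e-6):
    # Perturb arr in place, evaluate f, then restore the original value
    g = np.zeros_like(arr)
    for idx in np.ndindex(arr.shape):
        old = arr[idx]
        arr[idx] = old + eps
        hi = f()
        arr[idx] = old - eps
        lo = f()
        arr[idx] = old
        g[idx] = (hi - lo) / (2 * eps)
    return g


rng = np.random.default_rng(0)
z, x = rng.normal(size=(4, 2)), rng.normal(size=(4, 3))
W, b = rng.normal(size=(3, 2)), rng.normal(size=2)
grads = analytic_grads(z, x, W, b)
f = lambda: energy(z, x, W, b)
checks = {name: np.allclose(grads[name], numeric_grad(f, arr), atol=1e-5)
          for name, arr in {"z": z, "x": x, "W": W, "b": b}.items()}
```

These are gradients of E, so descent subtracts them; any scaling or accumulation stays with the callsite, as noted above.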
For nodes with distinct input types (e.g., data and mask):
    @staticmethod
    def get_slots():
        return {
            "in": SlotSpec(name="in", is_multi_input=True),
            "mask": SlotSpec(name="mask", is_multi_input=False),
        }

    @staticmethod
    def forward(params, inputs, state, node_info):
        # Access inputs by slot
        data_inputs = {k: v for k, v in inputs.items() if k.endswith(":in")}
        mask_inputs = {k: v for k, v in inputs.items() if k.endswith(":mask")}
        # Process separately
        # ...

NodeState is a fixed-schema NamedTuple with exactly these fields:
class NodeState(NamedTuple):
    z_latent: jnp.ndarray        # latent states (what the network infers)
    z_mu: jnp.ndarray            # predictions (what the network predicts)
    error: jnp.ndarray           # prediction errors (z_latent - z_mu)
    energy: jnp.ndarray          # per-sample energy, shape (batch,)
    pre_activation: jnp.ndarray  # values before the activation function
    latent_grad: jnp.ndarray     # gradient accumulator for inference updates

You cannot add custom fields to it. state._replace(...) only updates these existing fields, so there is no place to stash arbitrary per-step memory (such as an RNN hidden vector) on the NodeState.
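The restriction comes from NamedTuple itself, as a small stand-in shows. ToyNodeState copies NodeState's six fields and is hypothetical; the behavior of _replace is standard Python: it returns a new tuple and raises ValueError for any name that is not already a field.

```python
from typing import NamedTuple

import numpy as np


class ToyNodeState(NamedTuple):
    # Hypothetical stand-in with the same six fields as NodeState
    z_latent: np.ndarray
    z_mu: np.ndarray
    error: np.ndarray
    energy: np.ndarray
    pre_activation: np.ndarray
    latent_grad: np.ndarray


z = np.zeros((2, 3))
state = ToyNodeState(z, z, z, np.zeros(2), z, z)

# Returns a new tuple; `state` itself is unchanged
updated = state._replace(error=np.ones((2, 3)))

try:
    state._replace(hidden=np.ones(3))  # "hidden" is not a field
    added_field = True
except ValueError:
    added_field = False
```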
If your node needs additional state, route it through what already exists:
- Latent memory: anything the node should "remember" and infer over belongs in z_latent. The inference loop already carries z_latent across steps and updates it from the energy gradient.
- Inputs: state supplied by other nodes arrives through inputs, keyed by edge. Add an input slot (see Multiple Input Slots) to receive it.
- Parameters: fixed-per-graph quantities belong in params.weights / params.biases, allocated in initialize_params().
Adding a genuinely new piece of dynamic state (a field on NodeState) is a framework-level change, not something a single node can do on its own.
Write tests to verify:
- Shape correctness: Output shapes match the node's declared shape
- Energy decreases: Inference reduces free energy across steps
- Gradient flow: Learning updates weights in the expected direction
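The energy-decrease check can be prototyped without FabricPC on a hypothetical two-node linear chain: x and y are clamped, and only the hidden latent h is updated by gradient descent on E = 0.5 * ||h - x W1||^2 + 0.5 * ||y - h W2||^2. The step size 0.05 and 20 steps echo the InferenceSGD settings used in this guide; the loop itself is a plain stand-in for FabricPC's inference code.

```python
import numpy as np


def energy(h, x, y, W1, W2):
    # Hidden latent vs its prediction, plus clamped target vs its prediction from h
    return 0.5 * np.sum((h - x @ W1) ** 2) + 0.5 * np.sum((y - h @ W2) ** 2)


def grad_h(h, x, y, W1, W2):
    e_h = h - x @ W1          # error at the hidden node
    e_y = y - h @ W2          # error at the output node
    return e_h - e_y @ W2.T   # dE/dh


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))
y = rng.normal(size=(4, 2))
W1 = rng.normal(size=(5, 3)) / np.sqrt(5)
W2 = 0.5 * rng.normal(size=(3, 2))
h = rng.normal(size=(4, 3))  # random initial latent

eta, steps = 0.05, 20
energies = [energy(h, x, y, W1, W2)]
for _ in range(steps):
    h = h - eta * grad_h(h, x, y, W1, W2)
    energies.append(energy(h, x, y, W1, W2))
```

Because this energy is a convex quadratic in h, a small enough step size makes every step non-increasing, so the test can assert monotonic decrease as well as final < initial.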
Example test:
import jax
import jax.numpy as jnp
from fabricpc.core.topology import Edge
from fabricpc.graph_assembly import graph, TaskMap
from fabricpc.core.inference import InferenceSGD
from fabricpc.graph_initialization import initialize_params
# IdentityNode and Conv2DNode must also be importable in this test module

def test_conv2d_energy_decreases():
    """Test that inference reduces energy for Conv2D node."""
    # Build graph (default "SAME" padding keeps the 28x28 spatial size)
    input_node = IdentityNode(shape=(28, 28, 1), name="input")
    conv = Conv2DNode(shape=(28, 28, 16), kernel_size=(3, 3), name="conv")
    output = IdentityNode(shape=(28, 28, 16), name="output")

    structure = graph(
        nodes=[input_node, conv, output],
        edges=[
            Edge(source=input_node, target=conv.slot("in")),
            Edge(source=conv, target=output.slot("in")),
        ],
        task_map=TaskMap(x=input_node, y=output),
        inference=InferenceSGD(eta_infer=0.05, infer_steps=20),
    )

    # Use independent keys for parameters and each dummy array
    key_params, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)

    # Initialize parameters
    params = initialize_params(structure, key_params)

    # Create dummy data
    batch_size = 4
    x = jax.random.normal(key_x, (batch_size, 28, 28, 1))
    y = jax.random.normal(key_y, (batch_size, 28, 28, 16))

    # Run inference and record initial_energy and final_energy
    # ... (see existing test examples in tests/)

    # Assert energy decreases
    assert final_energy < initial_energy

Creating custom nodes involves:
- Subclass NodeBase: Define your node class
- Implement get_slots(): Specify input slots
- Implement initialize_params(): Allocate and initialize weights/biases
- Implement forward(): Compute predictions, errors, and energy
- Optional overrides:
  - get_weight_fan_in(): For correct muPC scaling
  - forward_and_latent_grads() / forward_and_weight_grads(): For explicit gradients
- Test: Verify shapes, energy convergence, and gradient flow
With these methods in place, your custom node integrates seamlessly with the rest of FabricPC's infrastructure: graph building, inference, learning, and scaling.