Skip to content

Latest commit

 

History

History
489 lines (378 loc) · 18.9 KB

File metadata and controls

489 lines (378 loc) · 18.9 KB

Writing Custom Nodes

FabricPC's built-in nodes cover many common use cases, but you may need custom nodes for specialized computations. This guide walks through creating your own node types.

When to Create a Custom Node

Built-in nodes include:

  • Linear: Fully-connected layers
  • IdentityNode: Passthrough nodes
  • StorkeyHopfield: Associative memory
  • TransformerBlock: Multi-head attention and feedforward

Create a custom node when you need:

  • Gating mechanisms (LSTM, GRU)
  • Custom transfer functions
  • Specialized domain-specific projections
  • Any computation not covered by built-in nodes

The NodeBase Contract

All nodes inherit from NodeBase and must implement three static methods:

  1. get_slots(): Define input slots (the connection interface)
  2. initialize_params(): Allocate and initialize weights and biases
  3. forward(): Compute predictions, errors, and energy

These methods are static because FabricPC uses a functional JAX-based design. Node instances hold configuration; the static methods define pure functional transformations.

forward() is where the node does its real work, and its body is intentionally unconstrained: how you turn inputs into a prediction is up to you. But a fixed set of steps must happen inside it for the node to participate in inference and learning — these are spelled out in Implement Forward Computation below. The split is:

  • Required (every node): produce z_mu, record pre_activation, compute error, write those fields back, populate energy via the energy functional, and return (total_energy, state).
  • Flexible (per node): how inputs are combined (sum, matmul, attention, embedding lookup), whether weights/biases exist, whether and which activation applies, any internal sub-structure (LayerNorm, attention, residual paths), and any extra energy terms.

Step-by-Step: Conv2D Node

Let's build a 2D convolutional node from scratch, to illustrate the node contract. (FabricPC now ships a production ConvNode in fabricpc.nodes; this from-scratch version is for teaching, not for use.)

Step 1: Define the Class and Constructor

import jax.numpy as jnp
import numpy as np
from fabricpc.nodes.base import NodeBase, SlotSpec
from fabricpc.core.types import NodeParams, NodeState, NodeInfo
from fabricpc.core.activations import ReLUActivation
from fabricpc.core.energy import GaussianEnergy
from fabricpc.core.initializers import NormalInitializer, initialize

class Conv2DNode(NodeBase):
    def __init__(
        self,
        shape,
        name,
        kernel_size,
        stride=(1, 1),
        padding="SAME",
        activation=ReLUActivation(),
        energy=GaussianEnergy(),
        latent_init=NormalInitializer(),
        weight_init=NormalInitializer(),
        **kwargs
    ):
        super().__init__(
            shape=shape,
            name=name,
            activation=activation,
            energy=energy,
            latent_init=latent_init,
            weight_init=weight_init,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            **kwargs
        )

Key points:

  • Accept standard node parameters (shape, name, activation, etc.)
  • Accept custom parameters (kernel_size, stride, padding)
  • Pass everything to super().__init__() via **kwargs
  • Custom parameters end up in node_info.node_config and are accessible in static methods

Step 2: Define Input Slots

@staticmethod
def get_slots():
    return {"in": SlotSpec(name="in", is_multi_input=True)}

This defines a single input slot named "in" that accepts multiple incoming edges. The is_multi_input=True flag means contributions from different source nodes will be summed.

For nodes with multiple distinct inputs:

@staticmethod
def get_slots():
    return {
        "in": SlotSpec(name="in", is_multi_input=True),
        "mask": SlotSpec(name="mask", is_multi_input=False),
    }

Step 3: Compute Weight Fan-In (Optional)

For muPC scaling, override get_weight_fan_in() to return the correct fan-in:

@staticmethod
def get_weight_fan_in(source_shape, config):
    kernel_size = config.get("kernel_size", (1, 1))
    C_in = source_shape[-1]  # channels-last format
    return C_in * int(np.prod(kernel_size))

For a 3x3 kernel with 16 input channels: fan_in = 16 * 3 * 3 = 144.

If you don't override this method, the default implementation uses the flattened source shape, which works for fully-connected layers but not for convolutions.

Step 4: Initialize Parameters

@staticmethod
def initialize_params(key, node_shape, input_shapes, weight_init=None, config=None):
    """
    Initialize convolutional kernels and biases.

    Args:
        key: JAX random key
        node_shape: Output shape of this node (H, W, C_out)
        input_shapes: Dict mapping edge_key -> source_shape
        weight_init: Weight initializer
        config: Node configuration dict

    Returns:
        NodeParams(weights=weights_dict, biases=biases_dict)
    """
    kernel_size = config.get("kernel_size", (1, 1))
    C_out = node_shape[-1]  # output channels

    weights = {}
    biases = {}

    for edge_key, source_shape in input_shapes.items():
        C_in = source_shape[-1]  # input channels

        # Kernel shape: (kH, kW, C_in, C_out)
        kernel_shape = (*kernel_size, C_in, C_out)

        # Initialize kernel weights
        subkey, key = jax.random.split(key)
        fan_in = C_in * int(np.prod(kernel_size))
        fan_out = C_out * int(np.prod(kernel_size))

        weights[edge_key] = initialize(
            subkey,
            kernel_shape,
            weight_init,
            fan_in=fan_in,
            fan_out=fan_out,
        )

        # Bias shape: (1, 1, 1, C_out) for broadcasting
        biases[edge_key] = jnp.zeros((1, 1, 1, C_out))

    return NodeParams(weights=weights, biases=biases)

Key points:

  • input_shapes is a dict keyed by edge identifiers (e.g., "conv1->conv2:in")
  • Each edge gets its own weight matrix and bias vector
  • Use the initialize() helper function with the provided weight_init initializer
  • Return NodeParams with dicts of weights and biases

Step 5: Implement Forward Computation

@staticmethod
def forward(params, inputs, state, node_info):
    """
    Forward pass: compute convolution, activation, error, and energy.

    Args:
        params: NodeParams with weights and biases dicts
        inputs: Dict mapping edge_key -> input_array
        state: Current NodeState
        node_info: NodeInfo with configuration

    Returns:
        (total_energy, updated_state)
    """
    # Extract config
    kernel_size = node_info.node_config.get("kernel_size", (1, 1))
    stride = node_info.node_config.get("stride", (1, 1))
    padding = node_info.node_config.get("padding", "SAME")
    activation = node_info.activation

    # Accumulate convolution outputs from all input edges
    pre_activation = None

    for edge_key, input_array in inputs.items():
        kernel = params.weights[edge_key]
        bias = params.biases[edge_key]

        # Perform convolution
        # JAX uses (batch, H, W, C) format
        conv_out = jax.lax.conv_general_dilated(
            lhs=input_array,
            rhs=kernel,
            window_strides=stride,
            padding=padding,
            dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        )

        conv_out = conv_out + bias

        if pre_activation is None:
            pre_activation = conv_out
        else:
            pre_activation = pre_activation + conv_out

    # Apply activation function
    z_mu = type(activation).forward(pre_activation, activation.config)

    # Compute prediction error
    error = state.z_latent - z_mu

    # Update state
    state = state._replace(
        pre_activation=pre_activation,
        z_mu=z_mu,
        error=error,
    )

    # Compute energy using the energy functional
    node_class = node_info.node_class
    state = node_class.energy_functional(state, node_info)

    # Return total energy and updated state
    total_energy = jnp.sum(state.energy)
    return total_energy, state

forward() is a pure function. It must have no side effects and must express its dependence on params, inputs, and state.z_latent entirely through JAX operations, because the framework differentiates it under jax.value_and_gradforward_and_latent_grads differentiates it with respect to inputs and z_latent, and forward_and_weight_grads with respect to params. Side effects or Python-level control flow on traced values produce wrong gradients during inference and learning.

Within that constraint, every forward() must perform these six steps in order:

  1. Predict z_mu: produce the node's prediction of its own latent, with shape (batch,) + node_info.shape. How is up to the node — a convolution here, a matmul in Linear, an attention pipeline in TransformerBlock.
  2. Record pre_activation: the value before the activation function. If the node applies no activation, set pre_activation = z_mu. pre_activation is planned for deprecation as persistent attribute of NodeState; it's actually an ephemeral intermediate to z_mu.
  3. Compute the error: error = state.z_latent - z_mu. The energy functionals assume this sign (latent minus prediction).
  4. Write the fields back: state = state._replace(z_mu=..., pre_activation=..., error=...). NodeState is a fixed-schema NamedTuple (z_latent, z_mu, error, energy, pre_activation, latent_grad); no other fields exist or may be added.
  5. Populate energy: node_class = node_info.node_class; state = node_class.energy_functional(state, node_info). This sets state.energy from energy(z_latent, z_mu), so z_mu must already be set. Extra energy terms (for example the Hopfield attractor term in StorkeyHopfield) are added by replacing state.energy after this call.
  6. Return: return jnp.sum(state.energy), state — the scalar total energy first, the updated state second.

The steps between predicting z_mu and writing it back are free: input aggregation, weights and biases, the choice of activation, and any internal sub-structure are all node-specific.

muPC scaling is not applied inside forward(). The inference and learning callsites scale inputs and gradients; doing so again here double-scales them. See the Initialization and Scaling guide.

Step 6: Use the Custom Node

from fabricpc.core.topology import Edge
from fabricpc.graph_assembly import graph, TaskMap

# Create nodes
input_node = IdentityNode(shape=(28, 28, 1), name="input")

conv1 = Conv2DNode(
   shape=(26, 26, 16),  # VALID padding: 28-3+1 = 26
   kernel_size=(3, 3),
   stride=(1, 1),
   padding="VALID",
   name="conv1",
)

conv2 = Conv2DNode(
   shape=(24, 24, 32),
   kernel_size=(3, 3),
   stride=(1, 1),
   padding="VALID",
   name="conv2",
)

output_node = Linear(
   shape=(10,),
   flatten_input=True,
   name="output",
)

# Build graph
structure = graph(
   nodes=[input_node, conv1, conv2, output_node],
   edges=[
      Edge(source=input_node, target=conv1.slot("in")),
      Edge(source=conv1, target=conv2.slot("in")),
      Edge(source=conv2, target=output_node.slot("in")),
   ],
   task_map=TaskMap(x=input_node, y=output_node),
   inference=InferenceSGD(eta_infer=0.05, infer_steps=20),
   scaling=MuPCConfig(),
)

Useful Patterns

FlattenInputMixin

For nodes that need dense (fully-connected) behavior with flattened inputs, use the FlattenInputMixin:

from fabricpc.nodes.base import FlattenInputMixin

class MyDenseNode(FlattenInputMixin, NodeBase):
    @staticmethod
    def forward(params, inputs, state, node_info):
        batch_size = state.z_latent.shape[0]
        out_shape = node_info.shape

        # Sum (flattened_input @ weight) over all input edges -> (batch, *out_shape)
        pre_activation = MyDenseNode.compute_linear(
            inputs, params.weights, batch_size, out_shape
        )

        # Add bias, if the node has one
        if "b" in params.biases and params.biases["b"].size > 0:
            pre_activation = pre_activation + params.biases["b"]

        # Apply activation
        z_mu = type(node_info.activation).forward(
            pre_activation, node_info.activation.config
        )

        # ... then compute error, update state, populate energy, and return
        #     (the six required steps above)

The mixin provides:

  • flatten_input(x): Flattens one input tensor from (batch, *shape) to (batch, numel)
  • reshape_output(x_flat, out_shape): Reshapes (batch, numel) back to (batch, *out_shape)
  • compute_linear(inputs, weights, batch_size, out_shape): Sums flattened_input @ weight across all input edges and reshapes to (batch, *out_shape). It does not add a bias — add it yourself, as shown above.

Explicit Gradients

By default, FabricPC computes gradients with JAX autodiff: it differentiates your forward() to obtain both the latent gradients (inference) and the weight gradients (learning). For hand-coded gradients (e.g. for efficiency or control), override forward_and_latent_grads() and forward_and_weight_grads(). These return gradients, not energy, so their signatures differ from forward():

class MyNode(NodeBase):
    @staticmethod
    def forward_and_latent_grads(params, inputs, state, node_info, is_clamped):
        # Run forward() for the updated state, then compute gradients analytically.
        node_class = node_info.node_class
        _, state = node_class.forward(params, inputs, state, node_info)

        # Self-latent gradient dE/dz_latent via the energy functional
        energy_obj = node_info.energy
        self_grad = type(energy_obj).grad_latent(
            state.z_latent, state.z_mu, energy_obj.config
        )

        # Per-edge input gradients dE/d_input (uses activation.derivative())
        input_grads = {...}

        # Returns (updated_state, input_grads, self_grad)
        return state, input_grads, self_grad

    @staticmethod
    def forward_and_weight_grads(params, inputs, state, node_info):
        node_class = node_info.node_class
        _, state = node_class.forward(params, inputs, state, node_info)

        # Compute weight/bias gradients analytically ...
        # Returns (updated_state, NodeParams(weights=..., biases=...))
        return state, NodeParams(weights=weight_grads, biases=bias_grads)

Note that muPC scaling and accumulation into state.latent_grad are handled by the callsite, not inside these overrides. See LinearExplicitGrad (fabricpc/nodes/linear_explicit_grad.py) for a complete example, including its compute_gain_mod_error() helper that combines state.error with activation.derivative().

Multiple Input Slots

For nodes with distinct input types (e.g., data and mask):

@staticmethod
def get_slots():
    return {
        "in": SlotSpec(name="in", is_multi_input=True),
        "mask": SlotSpec(name="mask", is_multi_input=False),
    }

@staticmethod
def forward(params, inputs, state, node_info):
    # Access inputs by slot
    data_inputs = {k: v for k, v in inputs.items() if k.endswith(":in")}
    mask_inputs = {k: v for k, v in inputs.items() if k.endswith(":mask")}

    # Process separately
    # ...

Working Within the Fixed NodeState Schema

NodeState is a fixed-schema NamedTuple with exactly these fields:

class NodeState(NamedTuple):
    z_latent: jnp.ndarray       # latent states (what the network infers)
    z_mu: jnp.ndarray           # predictions (what the network predicts)
    error: jnp.ndarray          # prediction errors (z_latent - z_mu)
    energy: jnp.ndarray         # per-sample energy, shape (batch,)
    pre_activation: jnp.ndarray # values before the activation function
    latent_grad: jnp.ndarray    # gradient accumulator for inference updates

You cannot add custom fields to it. state._replace(...) only updates these existing fields, so there is no place to stash arbitrary per-step memory (such as an RNN hidden vector) on the NodeState.

If your node needs additional state, route it through what already exists:

  • Latent memory: anything the node should "remember" and infer over belongs in z_latent. The inference loop already carries z_latent across steps and updates it from the energy gradient.
  • Inputs: state supplied by other nodes arrives through inputs, keyed by edge. Add an input slot (see Multiple Input Slots) to receive it.
  • Parameters: fixed-per-graph quantities belong in params.weights / params.biases, allocated in initialize_params().

Adding a genuinely new piece of dynamic state (a field on NodeState) is a framework-level change, not something a single node can do on its own.

Testing Your Node

Write tests to verify:

  1. Shape correctness: Output shapes match the node's declared shape
  2. Energy decreases: Inference reduces free energy across steps
  3. Gradient flow: Learning updates weights in the expected direction

Example test:

import jax
import jax.numpy as jnp
from fabricpc.core.topology import Edge
from fabricpc.graph_assembly import graph, TaskMap
from fabricpc.core.inference import InferenceSGD
from fabricpc.graph_initialization import initialize_params


def test_conv2d_energy_decreases():
   """Test that inference reduces energy for Conv2D node."""
   # Build graph
   input_node = IdentityNode(shape=(28, 28, 1), name="input")
   conv = Conv2DNode(shape=(28, 28, 16), kernel_size=(3, 3), name="conv")
   output = IdentityNode(shape=(28, 28, 16), name="output")

   structure = graph(
      nodes=[input_node, conv, output],
      edges=[
         Edge(source=input_node, target=conv.slot("in")),
         Edge(source=conv, target=output.slot("in")),
      ],
      task_map=TaskMap(x=input_node, y=output),
      inference=InferenceSGD(eta_infer=0.05, infer_steps=20),
   )

   # Initialize parameters
   rng_key = jax.random.PRNGKey(0)
   params = initialize_params(structure, rng_key)

   # Create dummy data
   batch_size = 4
   x = jax.random.normal(rng_key, (batch_size, 28, 28, 1))
   y = jax.random.normal(rng_key, (batch_size, 28, 28, 16))

   # Run inference and track energy
   # ... (see existing test examples in tests/)

   # Assert energy decreases
   assert final_energy < initial_energy

Summary

Creating custom nodes involves:

  1. Subclass NodeBase: Define your node class
  2. Implement get_slots(): Specify input slots
  3. Implement initialize_params(): Allocate and initialize weights/biases
  4. Implement forward(): Compute predictions, errors, and energy
  5. Optional overrides:
    • get_weight_fan_in(): For correct muPC scaling
    • forward_and_latent_grads() / forward_and_weight_grads(): For explicit gradients
  6. Test: Verify shapes, energy convergence, and gradient flow

With these methods in place, your custom node integrates seamlessly with the rest of FabricPC's infrastructure: graph building, inference, learning, and scaling.