Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions fabricpc/nodes/skip_connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""
Skip connection node for residual architectures.

SkipConnection is identical to IdentityNode in behavior — it sums inputs
and passes them through — but its slot has ``is_variance_scalable=False``.
This tells muPC to leave edges into this node unscaled (scale = 1.0),
preserving the identity mapping that carries signal through deep networks.
SkipConnection sums its inputs and scales by 1/√N (where N is the number
of inputs), then passes the result through. Its slot has
``is_variance_scalable=False``, so muPC leaves incoming edges at scale 1.0.

When used as a residual connection with N=2 (skip path + compute path),
the 1/√2 scaling keeps output variance ≈ 1 when both branches are
uncorrelated and have variance ≈ 1, preventing the variance doubling that
would occur from a raw sum. For a single input (N=1) the node behaves as an
identity pass-through.

Use SkipConnection for residual/skip paths in your graph. Use IdentityNode
for summation points where all inputs are independent and should be
Expand Down Expand Up @@ -42,11 +46,16 @@

class SkipConnection(NodeBase):
"""
Skip connection node: sums inputs without muPC variance scaling.
Skip connection node: sums inputs with 1/√N variance scaling.

Sums all inputs and scales by 1/√N (where N is the number of connected
inputs), then passes the result through without learnable parameters.
Its slot has ``is_variance_scalable=False``, so muPC leaves incoming
edges at scale 1.0.

Identical to IdentityNode in forward behavior (sums all inputs, no
learnable parameters). Its slot has ``is_variance_scalable=False``,
so muPC leaves incoming edges at scale 1.0.
The 1/√N scaling preserves unit output variance when all N input branches
are uncorrelated and have unit variance, preventing the N-fold variance
growth of a raw sum.
For a single input (N=1) the scaling is a no-op (identity pass-through).

This preserves the identity mapping through deep residual networks.
Without this, muPC's in-degree formula scales skip edges by
Expand Down Expand Up @@ -102,14 +111,20 @@ def forward(
state: NodeState,
node_info: NodeInfo,
) -> Tuple[jax.Array, NodeState]:
"""Sum all inputs and pass through (no transformation)."""
"""Sum all inputs and scale by 1/√N to maintain unit variance."""
pre_activation = None
for edge_key, x in inputs.items():
if pre_activation is None:
pre_activation = x
else:
pre_activation = pre_activation + x

# Scale by 1/√N when N > 1 to keep Var ≈ 1 when all N inputs have Var ≈ 1.
# For a single input (N=1) this is a no-op.
n_inputs = len(inputs)
if n_inputs > 1:
pre_activation = pre_activation / jnp.sqrt(jnp.float32(n_inputs))

z_mu = pre_activation # no activation function applied: z_mu = pre_activation
error = state.z_latent - z_mu
state = state._replace(
Expand Down
60 changes: 49 additions & 11 deletions fabricpc/nodes/transformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
tokens → Embedding → MhaResidual(+) → LnMlp1 → Mlp2Residual(+) → VocabProjection → logits
│ (skip) ↑ │ (skip) ↑
└──────────┘ └───────────┘

Variance control:
- MhaResidualNode: residual scaled by 1/√2; no sqrt(eff_ctx) compensation for softmax averaging is applied
- Mlp2ResidualNode: residual scaled by 1/√2
- create_deep_transformer: per-node dimension-appropriate initialization
"""

import math
from typing import Dict, Any, Tuple, Optional
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -240,7 +246,15 @@ def proj(h, w_name, b_name):
mha = jnp.matmul(attn, V).transpose(0, 2, 1, 3).reshape(B, L, D)
mha = proj(mha, "W_o", "b_o")

z_mu = x + mha
# Balanced residual: 1/√2 keeps Var ≈ 1 when both branches have Var ≈ 1.
# No explicit sqrt(eff_ctx) compensation is applied here because the QKV
# projections use dimension-appropriate initialization (std=1/sqrt(embed_dim)),
# so pre-softmax scores already have Var ≈ 1. The resulting peaked (non-uniform)
# attention means the attention output inherits Var ≈ Var(V) ≈ 1 directly,
# making position-dependent averaging compensation unnecessary and, in fact,
# harmful (it would over-amplify the last-position outputs).
inv_sqrt2 = jnp.float32(1.0 / jnp.sqrt(2.0))
z_mu = inv_sqrt2 * (x + mha)
error = state.z_latent - z_mu

state = state._replace(z_mu=z_mu, error=error)
Expand Down Expand Up @@ -364,7 +378,10 @@ def forward(params, inputs, state, node_info):
res_in = next(val for key, val in inputs.items() if key.endswith(":residual"))

mlp2 = jnp.dot(mlp1_in, params.weights["W_ff2"]) + params.biases["b_ff2"]
z_mu = res_in + mlp2

# Balanced residual: 1/√2 keeps Var ≈ 1 when both branches have Var ≈ 1.
inv_sqrt2 = jnp.float32(1.0 / jnp.sqrt(2.0))
z_mu = inv_sqrt2 * (res_in + mlp2)

error = state.z_latent - z_mu
state = state._replace(z_mu=z_mu, error=error)
Expand Down Expand Up @@ -446,17 +463,38 @@ def create_deep_transformer(
):
"""
Creates a deep transformer graph using the new class-based builder API.

Initialization strategy (per node type):
- EmbeddingNode: uses the initializer selected by ``weight_init`` (default:
normal with std 0.02; "xavier" selects Xavier, any other type Kaiming). A
small std is fine since embeddings are normalized by LayerNorm before attention.
- MhaResidualNode: NormalInitializer(std=1/sqrt(embed_dim)) so that Q/K/V/O
projections produce unit-variance outputs at init.
- LnMlp1Node: KaimingInitializer (He init) for the W_ff1 weight to compensate
for GELU halving variance.
- Mlp2ResidualNode: NormalInitializer(std=1/sqrt(mlp_dim)) for unit-variance
W_ff2 output.
- VocabProjectionNode: XavierInitializer for balanced fan-in/fan-out.
"""
# Embedding init: use user-provided std (small std is fine; embeddings are
# normalized by LayerNorm inside each MhaResidualNode before attention).
if weight_init is None:
w_init_obj = NormalInitializer(std=0.02)
embed_init = NormalInitializer(std=0.02)
else:
init_type = weight_init.get("type", "normal")
if init_type == "normal":
w_init_obj = NormalInitializer(std=weight_init.get("std", 0.05))
embed_init = NormalInitializer(std=weight_init.get("std", 0.02))
elif init_type == "xavier":
w_init_obj = XavierInitializer()
embed_init = XavierInitializer()
else:
w_init_obj = KaimingInitializer()
embed_init = KaimingInitializer()

# Dimension-appropriate initializers for each node type.
# These are independent of the user-specified weight_init to guarantee
# unit-variance signal propagation at initialization.
mha_init = NormalInitializer(std=1.0 / math.sqrt(embed_dim))
mlp1_init = KaimingInitializer() # He init compensates for GELU
mlp2_init = NormalInitializer(std=1.0 / math.sqrt(mlp_dim))
proj_init = XavierInitializer()

nodes = []
edges = []
Expand All @@ -471,7 +509,7 @@ def create_deep_transformer(
shape=(seq_len, embed_dim),
vocab_size=vocab_size,
embed_dim=embed_dim,
weight_init=w_init_obj,
weight_init=embed_init,
)
nodes.append(embed_node)
edges.append(Edge(source=input_node, target=embed_node.slot("in")))
Expand All @@ -484,7 +522,7 @@ def create_deep_transformer(
shape=(seq_len, embed_dim),
embed_dim=embed_dim,
num_heads=num_heads,
weight_init=w_init_obj,
weight_init=mha_init,
)
nodes.append(mha)
edges.append(Edge(source=previous_residual, target=mha.slot("in")))
Expand All @@ -495,7 +533,7 @@ def create_deep_transformer(
embed_dim=embed_dim,
ff_dim=mlp_dim,
activation=GeluActivation(),
weight_init=w_init_obj,
weight_init=mlp1_init,
)
nodes.append(mlp1)
edges.append(Edge(source=mha, target=mlp1.slot("in")))
Expand All @@ -505,7 +543,7 @@ def create_deep_transformer(
shape=(seq_len, embed_dim),
embed_dim=embed_dim,
ff_dim=mlp_dim,
weight_init=w_init_obj,
weight_init=mlp2_init,
)
nodes.append(mlp2)
edges.append(Edge(source=mlp1, target=mlp2.slot("in")))
Expand All @@ -518,7 +556,7 @@ def create_deep_transformer(
shape=(seq_len, vocab_size),
vocab_size=vocab_size,
embed_dim=embed_dim,
weight_init=w_init_obj,
weight_init=proj_init,
)
nodes.append(logits)
edges.append(Edge(source=previous_residual, target=logits.slot("in")))
Expand Down
Loading