Skip to content

Improve transformer_v2 model and add BPE support #28

Open
quinbez wants to merge 5 commits into
mainfrom
feature/transformer_block_v2
Open

Improve transformer_v2 model and add BPE support #28
quinbez wants to merge 5 commits into
mainfrom
feature/transformer_block_v2

Conversation

@quinbez

@quinbez quinbez commented Jun 6, 2026

Copy link
Copy Markdown
Collaborator

This PR improves the FabricPC Transformer V2 training pipeline, tuning workflow, and tokenization support. The changes introduce causal masking throughout the transformer graph, improve PC inference stability through muPC scaling and initialization changes, add a two-phase Bayesian hyperparameter tuning workflow using Optuna, and extend the training pipeline to support both character-level and BPE tokenization.

Key Changes

1. Transformer Graph and Autoregressive Training
Updated the Transformer V2 graph builder and training pipeline to support autoregressive language modeling.

  • Added muPC scaling (MuPCConfig) for variance control across layers
  • Added a dedicated causal mask node and routed masks to all attention blocks
  • Added skip-slot routing so muPC correctly accounts for residual depth
  • Wired causal masks through TaskMap and training pipeline
  • Switched to autoregressive training, evaluation, and generation APIs

2. Stability and Training Improvements
Improved optimization stability and predictive coding inference dynamics.

  • Replaced KLDivergenceEnergy with CrossEntropyEnergy at the output node, which is better suited for discrete token prediction
  • Added FeedforwardStateInit to initialize latent states from a forward pass rather than random noise, giving inference a better starting point
  • Improved embedding and output projection initialization while keeping transformer weights at std=0.02
  • Replaced InferenceSGD with InferenceSGDNormClip (max_norm=5.0) to prevent latent state divergence
  • Added cosine learning rate decay with alpha=0.1

3. Two-Phase Bayesian Hyperparameter Tuning
Introduced an Optuna-based two-phase tuning workflow.

  • Phase 1: architecture search minimizing energy with early pruning of unstable trials
  • Phase 2: fine-tuning continuous hyperparameters with fixed architecture, optimizing for perplexity
  • Scaled minimum inference steps with transformer depth using min_infer_steps = depth * 3 + 2 to ensure adequate inference budget per node.
  • Moved tuning outputs under fabricpc/tuning/

4. BPE Tokenization Support
Added BpeDataLoader in the library. On first use, the tokenizer is trained on all splits of Tiny Shakespeare and cached to disk. Subsequent runs load directly from cache.

  • Added tokenizer selection between character and BPE modes via --tokenizer argument
  • Added tuned default hyperparameters for each tokenizer
  • Extended tuning pipeline to support BPE experiments via a use_bpe flag

5. Autoregressive Generation

  • Replaced the manual per-token generation loop with generate_autoregressive, which uses jax.lax.scan with a sliding window for efficient fixed-shape generation.

Results

Run 1: Character-Level Tokenizer

Parameter Value
embed_dim 64
num_heads 4
mlp_dim 512
depth 3
infer_steps 18
eta_infer 8.90 × 10⁻²
lr 6.71 × 10⁻⁵
weight_init_std 4.39 × 10⁻²
Metric Value
Train Perplexity (Epoch 5) 1.59
Test CE Loss 2.79
Test Perplexity 16.35
Test Accuracy 28.48%
Training Time 2479s

Sample generation: ROMEO: hhon cavenzINLOMAch bexphers the perMervin ougARI bure th Cow ale. And alest gorfind cimengu Goukshay beme cave winchoube way ag k ureve the nomeaghil more mear fit Cot song f whe moree frks loont w

Note: This run used the manual generation loop. Generation was later switched to generate_autoregressive for Run 2.


Run 2: BPE Tokenizer

Parameter Value
embed_dim 256
num_heads 8
mlp_dim 512
depth 4
infer_steps 30
eta_infer 8.74 × 10⁻²
lr 4.83 × 10⁻⁵
weight_init_std 1.94 × 10⁻²
Metric Value
Train Perplexity (Epoch 5) 26.19
Test CE Loss 6.78
Test Perplexity 876.48
Test Accuracy 11.31%
Training Time 6643s

Sample generation: ROMEO : not those of these a him them of of his bed be , heart to prolong them a no with this himself them not not of Mercy breath go but ! to of heaven them him of

The large gap between train and test perplexity indicates overfitting. This is a known issue and is being addressed in future work.

Future Work

  • Investigate overfitting through dropout, weight decay, and reduced training duration
  • Benchmark on larger datasets that are better suited to BPE tokenization
  • Measure perplexity sensitivity to inference steps (3, 5, and 10 steps per node)
  • Compare PC training directly against the backpropagation trainer on identical architectures

Feedback on FabricPC

FabricPC lives up to its goal of making state-of-the-art PC accessible. The graph-based structure is clean and easy to understand. Being able to reason about each node and connection explicitly made it much easier to build and debug the transformer. We have been working with PyTorch for a while and had to build many utility functions for PC manually. FabricPC removed most of that overhead and let us focus on the architecture and training dynamics instead.

quinbez added 5 commits June 6, 2026 13:29
- Add MuPCConfig scaling to graph builder for variance control
- Add causal_mask node wired to all MhaResidualNode mask slots
- Add skip slot to MhaResidualNode so muPC counts residual depth correctly
- Wire causal_mask into TaskMap for autoregressive training
- Switch to train_autoregressive / evaluate_autoregressive / generate_autoregressive
- Add use_causal_mask to train_config
- Add MuPCConfig(include_output=False) scaling to graph builder for variance control
- Switch output node energy from KLDivergenceEnergy to CrossEntropyEnergy
- Add FeedforwardStateInit to graph builder for better latent state initialization
- Fix embedding weight init to NormalInitializer(std=1.0) for distinct token vectors
- Fix output projection weight init to NormalInitializer(std=sqrt(1/embed_dim))
- Switch InferenceSGD to InferenceSGDNormClip(max_norm=5.0, latent_decay=0.0)
- Add cosine decay schedule to optimizer for stable multi-epoch training
- Add model parameter count logging
- Add evaluation time logging
- Replace generate_autoregressive with manual generation loop that correctly clamps causal_mask at each step and samples from log_probs instead of probs
…a for FabricPC models.

- Phase 1: Architecture search: minimize energy, prune unstable trials early.
- Phase 2: Continuous fine-tuning: fix architecture, minimize perplexity.
- add BPE dataloaders, tokenizer selection, tuned BPE defaults
- switch to `generate_autoregressive` for text generation with `jax.lax.scan` sliding window
- scale minimum `infer_steps` with model depth during tuning
- move tuning outputs under `fabricpc/tuning`
@quinbez quinbez requested a review from matthewbehrend as a code owner June 6, 2026 10:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant