Improve transformer_v2 model and add BPE support #28
Open
quinbez wants to merge 5 commits into
Open
Conversation
- Add MuPCConfig scaling to graph builder for variance control - Add causal_mask node wired to all MhaResidualNode mask slots - Add skip slot to MhaResidualNode so muPC counts residual depth correctly - Wire causal_mask into TaskMap for autoregressive training - Switch to train_autoregressive / evaluate_autoregressive / generate_autoregressive - Add use_causal_mask to train_config
- Add MuPCConfig(include_output=False) scaling to graph builder for variance control - Switch output node energy from KLDivergenceEnergy to CrossEntropyEnergy - Add FeedforwardStateInit to graph builder for better latent state initialization - Fix embedding weight init to NormalInitializer(std=1.0) for distinct token vectors - Fix output projection weight init to NormalInitializer(std=sqrt(1/embed_dim)) - Switch InferenceSGD to InferenceSGDNormClip(max_norm=5.0, latent_decay=0.0) - Add cosine decay schedule to optimizer for stable multi-epoch training - Add model parameter count logging - Add evaluation time logging - Replace generate_autoregressive with manual generation loop that correctly clamps causal_mask at each step and samples from log_probs instead of probs
…a for FabricPC models. - Phase 1: Architecture search: minimize energy, prune unstable trials early. - Phase 2: Continuous fine-tuning: fix architecture, minimize perplexity.
- add BPE dataloaders, tokenizer selection, tuned BPE defaults - switch to `generate_autoregressive` for text generation with `jax.lax.scan` sliding window - scale minimum `infer_steps` with model depth during tuning - move tuning outputs under `fabricpc/tuning`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR improves the FabricPC Transformer V2 training pipeline, tuning workflow, and tokenization support. The changes introduce causal masking throughout the transformer graph, improve PC inference stability through
muPCscaling and initialization changes, add a two-phase Bayesian hyperparameter tuning workflow using Optuna, and extend the training pipeline to support both character-level and BPE tokenization.Key Changes
1. Transformer Graph and Autoregressive Training
Updated the Transformer V2 graph builder and training pipeline to support autoregressive language modeling.
muPCscaling (MuPCConfig) for variance control across layersmuPCcorrectly accounts for residual depthTaskMapand training pipeline2. Stability and Training Improvements
Improved optimization stability and predictive coding inference dynamics.
KLDivergenceEnergywithCrossEntropyEnergyat the output node, which is better suited for discrete token predictionFeedforwardStateInitto initialize latent states from a forward pass rather than random noise, giving inference a better starting pointstd=0.02InferenceSGDwithInferenceSGDNormClip (max_norm=5.0)to prevent latent state divergencealpha=0.13. Two-Phase Bayesian Hyperparameter Tuning
Introduced an Optuna-based two-phase tuning workflow.
min_infer_steps = depth * 3 + 2to ensure adequate inference budget per node.fabricpc/tuning/4. BPE Tokenization Support
Added
BpeDataLoaderin the library. On first use, the tokenizer is trained on all splits of Tiny Shakespeare and cached to disk. Subsequent runs load directly from cache.--tokenizerargumentuse_bpeflag5. Autoregressive Generation
generate_autoregressive, which usesjax.lax.scanwith a sliding window for efficient fixed-shape generation.Results
Run 1: Character-Level Tokenizer
Sample generation:
ROMEO: hhon cavenzINLOMAch bexphers the perMervin ougARI bure th Cow ale. And alest gorfind cimengu Goukshay beme cave winchoube way ag k ureve the nomeaghil more mear fit Cot song f whe moree frks loont wRun 2: BPE Tokenizer
Sample generation:
ROMEO : not those of these a him them of of his bed be , heart to prolong them a no with this himself them not not of Mercy breath go but ! to of heaven them him ofThe large gap between train and test perplexity indicates overfitting. This is a known issue and is being addressed in future work.
Future Work
Feedback on FabricPC
FabricPC lives up to its goal of making state-of-the-art PC accessible. The graph-based structure is clean and easy to understand. Being able to reason about each node and connection explicitly made it much easier to build and debug the transformer. We have been working with PyTorch for a while and had to build many utility functions for PC manually. FabricPC removed most of that overhead and let us focus on the architecture and training dynamics instead.