
Add NonErgodicGenerativeProcess and InflatedVocabularyProcess #172

Merged
ealt merged 40 commits into main from kyle/nonergodic on Mar 17, 2026

Conversation

@kylejray
Collaborator

Summary

  • Add NonErgodicGenerativeProcess — a block-diagonal mixture model composing multiple GenerativeProcess components with weighted probabilities, supporting HMM, GHMM, and FactoredGenerativeProcess components
  • Add InflatedVocabularyProcess — a wrapper that multiplies vocab size by factor K with uniform noise, increasing optimal loss by exactly log(K) nats
  • Add builder functions for disjoint and partial-overlap vocabulary configurations with prefix, sliding, and random vocab map modes
  • Use IndependentFactoredGenerativeProcess for independent structures (O(sum V_i) vs O(prod V_i) sampling)
  • Fix off-by-one in NonErgodic generate inference pass (prior vs posterior states)
  • Fix random vocab mode to independently sample tokens per component
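The first two bullets can be illustrated with a minimal pure-Python sketch. All names here are hypothetical; the actual classes operate on JAX arrays and compose real HMM/GHMM components, but the core idea — pick one component per sequence, then map its tokens into a disjoint region of the joint vocabulary — looks roughly like this:

```python
import random

def sample_nonergodic(component_samplers, vocab_sizes, weights, seq_len, rng):
    """Pick one component for the whole sequence (the mixture never switches
    mid-sequence, so each sequence stays inside one ergodic block), then shift
    its tokens into that component's disjoint slice of the joint vocabulary."""
    i = rng.choices(range(len(component_samplers)), weights=weights)[0]
    offset = sum(vocab_sizes[:i])  # component i owns tokens [offset, offset + V_i)
    return [offset + t for t in component_samplers[i](seq_len, rng)]
```

With two toy components of vocab sizes 2 and 3, sequences drawn from the second component always land in the token range [2, 5), which is the disjoint-vocabulary configuration the builder functions set up.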

Test plan

  • 21 unit tests for NonErgodicGenerativeProcess (block-diagonal structure, vocab mapping, generate/loss, multiple component types)
  • Tests for InflatedVocabularyProcess (vocab inflation, loss offset verification, state dynamics)
  • Tests for builder functions (disjoint/partial-overlap vocab modes, config-driven instantiation)

🤖 Generated with Claude Code

casperlchristensen and others added 24 commits December 9, 2025 10:16
* pylint (#98)

* add compatibility for factored states

* concrete examples and alternating process

* tweaks to vocab sizes

* update naming

* lock

* full merge, renaming

* test factored representation

* finalise gen-process PR

* update after merge

* static analysis

* static analysis tweaks

* arg name

* better test coverage

* factor input args

* ruff

* better linting

* bind i

* ellipsis to protocol

* simplify protocol

* format

* Minor fixes

* Minor fixes

* jnp.ndarray -> jax.Array

* Fix JIT compilation issue

Previous code extracted values from JAX arrays and converted them to Python ints at runtime. This fails when the function is JIT-compiled: JAX arrays become tracers during compilation, and calling int() on a tracer raises an error. The vocab_sizes parameter must therefore be provided to __init__ for this method to work with JIT.
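The shape of the fix can be sketched without JAX (class and method names are hypothetical, not the PR's actual code): convert sizes to concrete Python ints once, at construction time, so nothing inside a jitted method ever calls int() on a traced value.

```python
import numpy as np

class FactoredEmitter:
    """Sketch of the fix: vocab sizes become static Python ints in __init__,
    so jit-compiled methods never need int() on an array element."""

    def __init__(self, vocab_sizes):
        # int(...) on concrete array values is safe here, outside any trace
        self.vocab_sizes = tuple(int(v) for v in np.asarray(vocab_sizes))

    def split_token(self, token):
        """Decode a joint token into per-factor tokens using the static sizes."""
        factors = []
        for v in reversed(self.vocab_sizes):
            factors.append(token % v)
            token //= v
        return list(reversed(factors))
```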

* Refactor generative process config tests to use a helper method for creating factored process configurations. Added parameterized tests for valid and invalid configurations, improving test coverage and maintainability.

* Add docstrings

* Add match strings to value errors in tests

* add better factor handling and allow regression to individual factors

* pass device

* static analysis

* better output format

* to_factor in validation

* update returns and concatenations

* tuple handling

* fix typehint

* improve test coverage

---------

Co-authored-by: ealt <ealt@users.noreply.github.com>
Co-authored-by: Eric Alt <ericallenalt@gmail.com>
* Enhance PyTorch training with metric tracking and update configuration

- Introduced `TrainingMetricTracker` for stateful metric tracking during PyTorch training, allowing for detailed monitoring of loss, learning rates, and parameter updates.
- Updated `train_pytorch_model` to integrate the metric tracker, enabling automatic logging of training metrics.
- Added new metrics to track cumulative and instantaneous values, including loss averages and parameter norms.
- Modified `pyproject.toml` to include `reportUnnecessaryEllipsis` setting and added `diff-cover` as a development dependency.
- Expanded the README with documentation on the new `TrainingMetricTracker` and its usage.
- Added tests for the metric tracker to ensure accurate reporting of metrics during training.

* pylint (#98)

* add compatibility for factored states

* concrete examples and alternating process

* tweaks to vocab sizes

* Create design doc

* Implement solution

* Add plotly support

* Disable too many instance attributes in configs

* Replace dict with {}

* Import altair in a normal way

* Remove init

* Reorganize altair dependency in pyproject.toml

* Fix demo imports

* Refactor metric tracker

* Update metrics

* Add current loss metric enhancements

- Introduced additional metrics for tracking loss: minimum loss, moving average (MA), and exponential moving average (EMA).
- Updated the `compute` method to return these new metrics alongside the current loss.
- Enhanced the distance from initialization metric to track the maximum distance encountered during training.

* Fix bugs with metrics tracker

* Fix loss metrics

* update naming

* Rename metric tracker

* Refactor MetricTracker and metrics initialization

- Removed initial_loss and optimal_loss parameters from MetricTracker constructor.
- Introduced metric_kwargs to pass additional parameters for metrics initialization.
- Updated the _initialize_context and _initialize_metrics methods to accommodate changes.
- Enhanced CurrentLossMetric and LossProgressMetric to use kwargs for initialization, improving flexibility.

* Refactor MetricTracker and MetricContext to unify named parameters handling

- Renamed and consolidated handling of named parameters in MetricTracker and MetricContext.
- Updated methods to use a single `named_parameters` attribute instead of separate current and previous parameters.
- Adjusted metrics computations to reflect the new structure, ensuring consistency across metrics that rely on named parameters.

* Refactor MetricTracker and MetricContext to use unified token count

- Renamed `batch_tokens` and `total_tokens` to `num_tokens` in MetricContext and MetricTracker.
- Updated metrics calculations in TokensMetric, LearningRateWeightedTokensMetric, and GradientWeightedTokensMetric to reflect the new naming convention.
- Enhanced cumulative token tracking for improved clarity and consistency.

* Refactor metrics to use update method and improve computation

- Updated the `compute` method in various metrics to remove context dependency and introduced an `update` method for state management.
- Enhanced metrics such as TokensMetric, LearningRateMetric, and GradientWeightedTokensMetric to maintain internal state for more efficient calculations.
- Added new utility functions for L2 norm calculations across collections of tensors, improving performance and clarity in metric computations.

* Refactor LossProgressMetric to separate update and compute methods

- Introduced an `update` method to manage the current loss state, enhancing clarity and separation of concerns.
- Updated the `compute` method to calculate progress based on the current loss, improving the metric's functionality.

* Update TokensMetric to rename token metrics for clarity

- Changed metric keys from "tokens/batch" and "tokens/total" to "tokens/raw" and "tokens/raw/cumulative" to better reflect their purpose and improve consistency in naming conventions.

* Clear gradients and learning rates after metric computation in GradientWeightedTokensMetric and FisherInformationMetric for improved state management.

* Refactor MetricTracker to enhance metric group handling and requirements management

- Updated MetricTracker to initialize metric groups and requirement flags more efficiently.
- Modified the update method to support group-specific requirements for learning rates, gradients, and named parameters.
- Simplified the initialization of metrics by consolidating logic and improving clarity in the code structure.
- Added `update_every_step` attribute to several metrics for better state management during updates.

* Add logging for missing update keys in MetricTracker

- Introduced logging to warn when required update keys are missing for metric groups.
- Enhanced metric group handling by adding a method to identify missing update keys based on the `update_every_step` attribute.
- Improved clarity in the metric initialization process by consolidating logic for required metrics.

* Refactor L2 norm computation in metrics.py

- Simplified the docstring for the _tensor_collection_l2_norms function to focus on its core functionality.
- Removed unnecessary casting to CPU in the _named_tensor_distance function to streamline tensor operations.

* Refactor metric computations to utilize new utility functions

- Replaced internal L2 norm and distance calculations in metrics.py with calls to the newly defined tensor_collection_l2_norm and named_tensor_distance functions from pytorch_utils.py.
- Updated docstrings for clarity and removed redundant comments to streamline the codebase.

* Refactor MetricTracker and metrics protocol for improved clarity

- Renamed the TrainingMetric protocol to Metric for better alignment with its purpose.
- Updated the MetricTracker's _initialize_metrics method to utilize the new Metric protocol, enhancing type consistency and clarity in metric initialization.

* Refactor metrics to utilize tensor_stack_l2_norm for improved efficiency

- Replaced instances of tensor_collection_l2_norm with tensor_stack_l2_norm in various metrics for optimized L2 norm calculations.
- Simplified the update and compute methods in GradientWeightedTokensMetric, CumulativeParameterUpdateMetric, and FisherInformationMetric to enhance state management and clarity.
- Removed redundant internal functions for L2 norm and distance calculations, streamlining the codebase.

* Remove metric tracker

* add activation analysis work

* reformat tests

* move protocol and inherit

* add example

* less jax conversions

* jax first

* claude feedback

* better types

* fix tests

* fix initialisation from config

* use protocol only for duck-typing

* error handling

* simplified docstrings, unused variables

* pyright protocol

* analyses tweaks

* typing

* refactor: Split `MetricTracker.update` into `step` and `update_metrics`, and optimize tensor operations in `named_tensor_distance`, gradient extraction, and parameter snapshots by removing CPU transfers and vectorizing calculations.

* Add configs and metric tracker in run management

* update pr

* fix uv

* pin for transformer_lens compatibility

* add layerwise analysis classes

* use correct return class

* dataclass access notation

* fix lock

* ruff format

* pylint

* remove unused arg

* fix tests after refactor

* linter happiness

* separate responsibilities of generative processes

* revert

* add activation tracker test

* add activation tracker test

* final feedback

* proper instantiate

* no aliasing

* remove unneeded sklearn

* simplify last token

* fix tests after refactor

* remove unused dict and handle div by 0

* add tests to analysis functions

* better coverage

* make pyright happy

* add config coverage

* pull out normalization functions

* be more explicit about missing data/features

* PR feedback

* missing docstrings

* make methods public and document

* prepare options

* unnecessary conversion

* missing docstring

* formatting

* use explicit typehints

* use prepare options in tests

* change tests to JNP

* unused import

* update test coverage

* wip activation visualization

* merge

* add lock

* run with scalars

* temporary commit for merge

* static analysis checks

* update after facet-plots

* mute final pylint warnings

* fix final static analyses

* get rid of type alias

* small e2e test only

* use activation tracker in e2e

* fix yaml structure

* remove unused config

* fix end to end tests

* add more coverage

* add more tests

* refactor to more modularity

* training config

* add schedulers, proper bos handling in loop

* remove unnecessary comment

* add schedulers to e2e configs

* handle bos-token behaviour in tests

* get rid of large docs file

* add LR scheduler tests

* add LR schedulers for exact recreations of Adam's plots

* fix pyright

* only little test

* fix colour test

* make altair optional

* make altair obligatory

* simplify conversions

* consolidations

* Delete tests/end_to_end/configs/demo_config_with_visuals.yaml

* Delete tests/end_to_end/configs/demo_config_with_visuals.py

* consolidation (again)

* address PR feedback

* better typing

---------

Co-authored-by: Eric Alt <ericallenalt@gmail.com>
Co-authored-by: ealt <ealt@users.noreply.github.com>
Co-authored-by: Casper Lutzhoft Christensen <casper@g488.voltagepark.net>
* Refactor regression code to incorporate optional computation of pairwise subspace orthogonality metrics

* Refine regression API and add comprehensive orthogonality tests

- Separate coeffs/intercept in return structure (omit intercept key when
  fit_intercept=False)
- Rename to_factors → concat_belief_states for clarity
- Add 9 orthogonality tests with principled numerical thresholds
  (safety_factor=10)
- Test orthogonal, aligned, contained subspaces; multi-factor scenarios;
  edge cases
- Update validators and existing tests for new parameter structure
- Add informative assertion messages for debugging numerical precision

* Organize imports

* Fix lint issues

* Fix slices

* Simplify lr kwarg validation

* Add return type

* Add pylint ignore

* Fix potential division by zero

* Fix potential log(0) issue

* Enhance subspace orthogonality computation by adding a check for multiple belief states. Log a warning if only one belief state is present, preventing unnecessary calculations.

* Fix docstring inconsistency

* Update docstring

* Fix lint issues

* Refactor linear regression kwargs validation and improve logging. Temporarily disable pylint checks during AST traversal to avoid crashes related to package imports.

* Fix merge conflict

* Amended unseen merge conflict in linear_regression tests

* Rename to_factors parameter to concat_belief_states in activation analyses

* Update activation analysis tests for concat_belief_states semantics

* Fix validator error message and fix linting issues

* Add check requiring 2+ factors in _handle_factored_regression and remove redundant orthogonality computations warning

* Add proper spacing to warning messages

* Fix dictionary equivalence check in test_linear_regression and add blank line after docstring in test_layerwise_analysis

* Refactor subspace orthogonality computation for JIT compatibility

* Fix conditional callback execution using jax.lax.cond

* Fix linting and formatting issues

* Fix formatting issues

* Disable too-many-locals linting issue in test_linear_regression.py

* Change name of return dict from singular_values -> arrays for clarity

* Add docstring describing return values for _compute_all_pairwise_orthogonality function

* Add docstring describing relevance of the do_nothing_branch function

* Refactor key removal method in kwarg validator and fix docstring format

* Temporarily disable pylint checks during AST traversal in linear_regression.py to prevent crashes. Remove deprecated layer_linear_regression_svd function for cleaner code and encourage use of layer_linear_regression with use_svd=True.

* Refactor linear regression analysis registration to use partial application of layer_linear_regression with use_svd=True, removing the deprecated layer_linear_regression_svd function for improved clarity and consistency.

* Fix tests

* Add detailed docstring to _compute_subspace_orthogonality function, specifying return values and their meanings for improved clarity and documentation.

* Add todo

* Fix kwarg validation

* Fix tests

* Add validator decorator for linear_regression_svd to enforce use_svd=True and exclude it from output. Enhance tests to validate behavior.

* Fix test

* Add get_robust_basis for robust orthonormal basis extraction

* Pass pair of bases instead of coefficient matrices to _compute_subspace_orthogonality

* Compute full rank and orthonormal basis of coeff matrices before passing bases to subspace analysis

* Fix formatting and docstring

* Update comment

* Fix issues due to API changes in activation and dataframe tests

* Fix formatting issues

---------

Co-authored-by: Eric Alt <ericallenalt@gmail.com>
…nalysis and LinearRegressionSVDAnalysis (#140)

* Simplify

* Refactor metrics and tracker

* Rename step group

* Renames

* Update metric tracker config validation

* Make metric tracker context non-private

* Get initial loss from context

* Add metric tracker to e2e test

* Remove example

* Fix config name

* Change dict to mapping to handle DictConfig

* Fix bug in updating lr

* Remove unused return value, simplify method call

* Refactor metric naming conventions for consistency and clarity. Update metric keys to include context and step information, and rename CurrentLossMetric to LossMetric for better understanding.

* Add loss progress to LossMetric

* Refactor requirements formatting in metrics for improved readability and consistency

* Enhance ParameterNormMetric to compute both parameter and weight norms, consolidating metrics into a single return statement. Remove WeightNormMetric class as its functionality is now integrated.

* Rename keys, merge fisher proxy into grad weighted tokens

* Update names

* Enhance MetricTracker and LossMetric to support custom step values, improving flexibility in metric tracking and loss computation.

* Remove step from context

* Add eval metric tracker to training

* Remove weights norm

* Check if metric names is a list config

* add instance to metric tracker keys

* Disable databricks.sdk info logs

* Configure devices to be the same

* Rename experiment/run names

* Add tokens per second metrics

* Detach loss before converting to float

* Create full training configs

* Update uv.lock

* ruff format

* Avoid div by zero

* lock

* full merge, renaming

* Fix training test

* test factored representation

* Fix device mismatch

* Device mismatch pt 2

* finalise gen-process PR

* update after merge

* static analysis

* static analysis tweaks

* arg name

* better test coverage

* factor input args

* ruff

* better linting

* bind i

* ellipsis to protocol

* simplify protocol

* format

* hack to get training working again

* Simplify components key

* Change metrics returns

* Update optimizer handling to log warnings for multiple optimizers and return None instead of the first optimizer.

* Create tests for requirements

* learning rates metric test

* Tokens metric test

* lr weighted tokens test

* gradient weighted tokens test

* parameter update test

* Have loss progress approach zero instead of one

* loss metric test

* param norm test

* parameter distance test

* uv sync

* Test pytorch utils

* Create metric groups property

* Create metric tracker tests

* add xavier's leaky RRXOR (#130)

* Update workflows to support dev branch ruleset standards

* Update GitHub workflows to correctly reference pull request base branches in conditions

* feat: Add `compute_subspace_orthogonality` option to `LinearRegressionAnalysis` and `LinearRegressionSVDAnalysis` to expose subspace metrics, along with corresponding tests.

---------

Co-authored-by: Casper Lutzhoft Christensen <clu@corti.ai>
Co-authored-by: Casper Lützhøft Christensen <61698286+casperlchristensen@users.noreply.github.com>
* fix slider rendering

* fix reference

* update tests post bug-fix

* static analysis
* Add simplexity-multirun CLI for parallel experiment execution

Add a new CLI tool for running multiple Hydra experiments in parallel
across GPUs or CPU workers with proper device isolation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix pylint and ruff linting issues

- Add pylint disable comments for too-many-arguments, too-many-locals, etc.
- Initialize variables before conditional to fix possibly-used-before-assignment
- Use raw docstring (r""") for backslash escapes
- Add strict=True to zip() call

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Refactor run_parallel to separate job generation from dispatch

- Add Job dataclass with to_cmd() method for rendering commands
- Extract generate_jobs() as a pure function for testability
- Extract dispatch_jobs() to encapsulate ProcessPoolExecutor logic
- Simplify main() to two-phase structure: generate then dispatch
- Dry-run now exits before dispatch instead of passing through executor
- Add tests for Job and generate_jobs() (GPU round-robin, sweep expansion)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add missing docstrings to test methods

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: adamimos <adam@g093.voltagepark.net>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
* save more path-specific visualizations

* update test path
* generic d_vocab resolution

* rename and test

* static analysis
* Abbreviate linear regression scalar metric names

* Add method to make layer names more compact and to construct layer-specific metric key names

* Integrate layer formatting methods into LayerwiseAnalysis

* Update visualization key lookup to use new {analysis}/{layer} format

Change projection and scalar key resolution to use the new naming convention
where keys follow {analysis}/{layer_spec} format (e.g., "pca/L0.resid.pre")
instead of the old {layer}_{analysis} format (e.g., "layer_0_pca").

Key changes:
- Update _lookup_projection_array and _lookup_scalar_value to match keys
  by prefix (analysis/) rather than suffix (_analysis)
- Add _key_matches_layer helper to handle factor-suffixed keys like
  "projected/layer_0-F0" when given pattern "projected/F0"
- Update _expand_projection_key_pattern to extract factor suffixes from
  new format and reconstruct pattern-matchable keys
- Update _expand_scalar_pattern_keys to properly handle analysis prefix
  for patterns with internal slashes

Update all test files to use new key format in mock data.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Format layer names in visualization lookups and update test key assertions

- Add format_layer_spec to field_resolution.py for converting layer names
  (e.g., blocks.0.hook_resid_pre → L0.resid.pre) before key matching
- Update dataframe_builders.py to format layer names in scalar series
  inference and DataFrame construction
- Update test_linear_regression.py assertions to new key format:
  - factor_X/metric → metric/FX
  - orthogonality_X_Y/metric → orth/metric_short/FX,Y
  - concat/metric → metric/Fcat
- Update test_layerwise_analysis.py assertions to new key format:
  - layer_metric → metric/layer
- Update with_visuals.yaml config templates to match new format
- Update test_activation_tracker_config.py key assertion

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Abbreviate PCA metric names for consistency

- variance_explained → var_exp
- n_components_{pct}pct → nc_{pct}
- cumvar_{idx} unchanged

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix formatting and linting

* Format layer names in pattern expansion for projection key matching

The pattern expansion logic was using unformatted layer names (e.g.,
'blocks.0.hook_resid_pre') to match against projection keys that have
formatted layer names (e.g., 'projected/L0.resid.pre'). This caused
pattern matching to fail when expanding projection key patterns like
'projected/F*' for non-concatenated layers.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Eric/improve-metric-naming-for-length-and-readability (#156)

* simplify metric_keys.py

* Update field resolution

* Remove test

* Simplify pattern expansion

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: ealt <ealt@users.noreply.github.com>
* add xavier's leaky RRXOR (#130)

* reduce number of metrics returned from variance analysis

* rename

* Update simplexity/activations/visualization/pattern_expansion.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* abbreviate

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…#167)

Add support for:
- Top-level hooks (hook_embed → embed)
- Block component hooks (blocks.N.{comp}.hook_X → LN.{comp}.X)
- ln_final hooks (ln_final.hook_X → ln_final.X)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
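The three mapping rules listed above can be sketched as a single formatting function (a hypothetical re-implementation, not the repo's actual format_layer_spec):

```python
import re

def format_layer_spec(name: str) -> str:
    """Abbreviate TransformerLens-style hook names per the rules above."""
    if name.startswith("ln_final.hook_"):        # ln_final.hook_X -> ln_final.X
        return "ln_final." + name[len("ln_final.hook_"):]
    if name.startswith("hook_"):                 # hook_embed -> embed
        return name[len("hook_"):]
    # blocks.N.hook_X -> LN.X  and  blocks.N.{comp}.hook_X -> LN.{comp}.X
    m = re.match(r"blocks\.(\d+)\.(?:(\w+)\.)?hook_(\w+)$", name)
    if m:
        layer, comp, hook = m.groups()
        hook = hook.replace("_", ".")            # resid_pre -> resid.pre
        return f"L{layer}.{comp}.{hook}" if comp else f"L{layer}.{hook}"
    return name                                  # unknown names pass through
```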

* Add IndependentFactoredGenerativeProcess for frozen factor support

Introduces a new generative process subclass that samples emissions from
each factor independently and supports "frozen" factors whose sequences
are identical across batch samples. This enables generating datasets where
k factors share realizations while (n-k) factors vary independently.

Key features:
- Per-factor independent emission sampling (not from joint distribution)
- Frozen factors specified via frozen_factor_indices and frozen_key
- Dual key stream approach: frozen factors use shared key, unfrozen use per-sample keys

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
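The dual key stream idea can be sketched with the stdlib RNG standing in for jax.random keys (function and parameter names hypothetical): frozen factors draw from one shared stream that is reset identically for every batch element, while unfrozen factors draw from per-sample streams.

```python
import random

def generate_independent(batch_size, seq_len, factor_vocabs, frozen, frozen_seed, seed):
    """Per-factor independent sampling with frozen factors shared across the batch."""
    batch = []
    for b in range(batch_size):
        sample_rng = random.Random(f"{seed}-{b}")  # per-sample key stream
        frozen_rng = random.Random(frozen_seed)    # shared key, reset each sample
        sample = []
        for f, vocab in enumerate(factor_vocabs):
            rng = frozen_rng if f in frozen else sample_rng
            sample.append([rng.randrange(vocab) for _ in range(seq_len)])
        batch.append(sample)
    return batch
```

Because the frozen stream restarts from the same seed for every sample, frozen factors produce identical sequences across the batch while the remaining factors vary independently.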

* Simplify IndependentFactoredGenerativeProcess implementation

- Remove _generate_with_frozen method, merge logic into single generate method
- Move None handling to edges (emit_observation and generate) so
  _emit_observation_per_factor always receives valid arrays
- Remove unnecessary super().generate() delegation
- Cleaner code structure with fewer methods and one code path

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix formatting

* Fix pylint warnings in IndependentFactoredGenerativeProcess tests

Replace unnecessary lambdas with direct method references and add
pylint disable for too-few-public-methods on TestStateTransitions.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
* return targets (#163)

* Apply Eric's review suggestions

- Use tuple(map(int, vocab_sizes)) in factored_generative_process.py
- Use math.prod(vocab_sizes) in noisy_channel.py for cleaner code

Co-authored-by: Casper Lützhøft Christensen <casperlchristensen@users.noreply.github.com>

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Casper Lützhøft Christensen <casperlchristensen@users.noreply.github.com>
Implements a truly nonergodic generative process that composes multiple
GenerativeProcess components with weighted mixture probabilities. Key features:

- Block diagonal structure where each component owns disjoint state space
- Wraps each component's generate() to avoid exponential joint state space
- Vocab mapping to handle different component vocabularies
- Support for HMM, GHMM, and FactoredGenerativeProcess components
- Builder function for config-driven instantiation
- Full test coverage with 21 unit tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The inference scan returned posterior states (after seeing each
observation) instead of prior states (before seeing each observation),
misaligning with the base GenerativeProcess.generate contract.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…oise

Wraps any GenerativeProcess by adding a stateless uniform noise dimension,
multiplying vocab size by K. Each base token t becomes K inflated tokens,
testing whether models can discover which part of the token carries state info.
Optimal loss increases by exactly log(K) nats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
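As a hedged illustration only (the mapping scheme and function names here are assumptions, not the PR's actual implementation), the inflate/deflate arithmetic and the log(K) loss offset described above can be sketched as:

```python
import math
import numpy as np

K = 4  # inflation factor (assumed value for the sketch)

def inflate(token, rng):
    # each base token t becomes one of K tokens t*K + j, with j uniform in [0, K)
    return token * K + rng.integers(0, K)

def deflate(inflated_token):
    # the base token is recoverable; the noise index carries no state information
    return inflated_token // K

rng = np.random.default_rng(0)
base = rng.integers(0, 3, size=1000)
inflated = np.array([inflate(t, rng) for t in base])

# the noise index is uniform and independent of the state, so the optimal
# per-token loss rises by exactly its entropy, log(K) nats
extra_loss_nats = math.log(K)
```

The deflate map is many-to-one, which is what lets a model that discovers the structure pay only the irreducible log(K) penalty.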
…edGenerativeProcess for independent structures

- Add build_nonergodic_disjoint_vocab and build_nonergodic_partial_overlap
  builders with prefix, sliding, and random vocab map modes
- Modify build_factored_process to return IndependentFactoredGenerativeProcess
  for independent structures (O(sum V_i) sampling vs O(prod V_i))
- Add noise_epsilon parameter to IndependentFactoredGenerativeProcess
- Add 18 new tests covering all new functionality

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Previously, random mode applied a global permutation to the prefix maps,
making the overlap structure identical to prefix. Now each component
independently samples V tokens from the global pool, producing genuinely
random overlap patterns.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
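A minimal sketch of the corrected behavior (function and parameter names are illustrative, not the PR's API): each component independently draws its own V tokens from the global pool, so pairwise overlap is genuinely random rather than fixed by a shared prefix:

```python
import numpy as np

def random_vocab_maps(num_components, component_vocab_size, global_vocab_size, seed):
    # each component samples V distinct tokens from the global pool on its own,
    # instead of applying one global permutation to shared prefix maps
    rng = np.random.default_rng(seed)
    return [
        rng.choice(global_vocab_size, size=component_vocab_size, replace=False)
        for _ in range(num_components)
    ]

maps = random_vocab_maps(num_components=3, component_vocab_size=5,
                         global_vocab_size=8, seed=0)
```

Seeding the generator once keeps the maps reproducible while still letting each component's draw differ.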
@kylejray
Copy link
Collaborator Author

@claude please review this PR!

@claude
Copy link

claude bot commented Feb 25, 2026

Claude finished @kylejray's task in 3m 6s —— View job


PR Review: Add NonErgodicGenerativeProcess and InflatedVocabularyProcess

Todo List:

  • Gather context and understand the PR
  • Read changed files
  • Analyze code quality, correctness, and test coverage
  • Post review feedback

Overall this is a well-structured PR with clear semantics and solid test coverage. The core math is correct. I found one meaningful correctness bug, a few design issues, and some style violations to flag.


Critical Bug: _slice_belief_states Temporal Misalignment

File: simplexity/generative_processes/generator.py:107-113

When generate(return_all_states=True) is called on a NonErgodicGenerativeProcess, JAX's scan stacks the inference trajectory into a NonErgodicState where both component_beliefs and component_states[i] have the time dimension. After vmap, shapes are:

  • component_beliefs: [B, T, num_components]
  • component_states[i]: [B, T, state_dim_i]

But _slice_belief_states only slices component_beliefs and leaves component_states unmodified:

if isinstance(belief_states, NonErgodicState):
    return NonErgodicState(
        component_beliefs=belief_states.component_beliefs[:, seq_slice, ...],  # sliced
        component_states=belief_states.component_states,  # ← NOT sliced, still [B, T, ...]
    )

This produces a NonErgodicState with mismatched time dimensions between component_beliefs and component_states whenever generate_data_batch_with_full_history is called. Fix this →

Related: there is no test for NonErgodicGenerativeProcess.generate(return_all_states=True), which is especially important given the off-by-one fix mentioned in the PR description.


Missing Validation: Assumed Equal Vocab Sizes in build_nonergodic_partial_overlap

File: simplexity/generative_processes/builder.py:846

v = temp.components[0].vocab_size  # silently ignores other components

If components have different vocab sizes (e.g. a coin and a mess3), n_shared and n_unique are computed only from component 0. For "sliding" and "random" modes this silently produces incorrect vocab maps. Should validate that all components share the same vocab_size before proceeding, or at minimum document the constraint. Fix this →
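A hedged sketch of the suggested guard (the helper name is illustrative, and the components are assumed to expose a vocab_size attribute):

```python
def require_shared_vocab_size(components):
    # sliding/random overlap modes assume every component has the same
    # vocab_size; fail loudly instead of silently using component 0's value
    sizes = sorted({c.vocab_size for c in components})
    if len(sizes) != 1:
        raise ValueError(f"components must share a vocab size, got {sizes}")
    return sizes[0]
```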


Inline Comments Violate CLAUDE.md Style Guidelines

File: simplexity/generative_processes/nonergodic_generative_process.py:17-37 and throughout generate()

CLAUDE.md states: "No Comments: Avoid inline comments; code should be self-documenting through clear naming and structure."

The file contains a large module-level comment block (lines 17–37) explaining the flatten/pad/unflatten strategy, plus extensive inline comments inside generate(). The explanatory content is valuable but should live in the class/method docstrings rather than as free-floating comment blocks.


Double Build in build_nonergodic_disjoint_vocab

File: simplexity/generative_processes/builder.py:765-774

The function builds the entire process once to query vocab sizes, then rebuilds it with the disjoint maps — creating all JAX arrays (transition matrices, inverse maps, etc.) twice. A cleaner pattern would be to extract a helper that builds only the component processes, queries their vocab_size, computes the maps, and then wraps once. Since build_nonergodic_process_from_spec is called twice this is a concrete inefficiency.


Missing Tests

  1. return_all_states=True for NonErgodicGenerativeProcess — No test exercises the inference trajectory path, which is precisely where the off-by-one fix lives. This is the most important missing test.

  2. FactoredGenerativeProcess as a component — The PR explicitly supports factored component types and includes the path in build_nonergodic_process_from_spec, but there is no test for this combination (which uses the flatten/pad/unflatten logic for heterogeneous state types).

  3. generate_data_batch_with_full_history with NonErgodicGenerativeProcess — Would catch the _slice_belief_states bug described above.


Minor Notes

  • build_inflated_process (builder.py:864-877) is a one-line passthrough to InflatedVocabularyProcess(base_process, inflation_factor) with no additional logic. Per CLAUDE.md ("Don't create helpers, utilities, or abstractions for one-time operations"), this wrapper likely doesn't justify its existence.

  • YAML config fields (nonergodic_example.yaml:7-8): base_vocab_size: ??? and vocab_size: ??? are Hydra "required" markers. Since these fields aren't consumed by the builder function, any Hydra instantiation attempt will error unless these are explicitly provided or removed.

  • NonErgodicState.component_states type is tuple[Any, ...]. A union type tuple[jax.Array | tuple[jax.Array, ...], ...] would be more precise and improve IDE/pyright support.


What Looks Good

  • The Bayesian filtering in transition_states is mathematically correct, including the safe-indexing with jnp.clip and the 0-likelihood guard for Bayes update.
  • The lax.switch + flatten/pad/unflatten approach for heterogeneous state types in generate is well-reasoned and the comments in the generate docstring explain it clearly.
  • InflatedVocabularyProcess is clean and the log(K) loss increase property is correctly implemented and tested.
  • The vocab map builder tests are thorough, covering disjoint, prefix, sliding, and random modes with seed reproducibility checks.
  • IndependentFactoredGenerativeProcess is correctly used for independent structures (O(sum V_i) vs O(prod V_i)).

Kyle Ray and others added 4 commits February 26, 2026 02:17
- Fix _slice_belief_states to slice component_states for NonErgodicState
- Add vocab size validation in build_nonergodic_partial_overlap
- Extract _build_components_from_spec to eliminate double-building
- Move inline comments to docstrings in nonergodic_generative_process.py
- Tighten NonErgodicState.component_states type to tuple[ComponentState, ...]
- Add tests for return_all_states, factored components, and full history

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Keep NonErgodicState as an internal detail rather than leaking it
into the public generator interface, which would cascade into
torch_generator and other consumers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@kylejray
Copy link
Collaborator Author

@claude can you take another look?

@claude
Copy link

claude bot commented Feb 26, 2026

Claude finished @kylejray's task in 2m 15s —— View job


Re-reviewed PR #172 — All critical and major issues from the previous review have been addressed. Two minor items remain.

Todo List

  • Gather context
  • Read changed files to assess what's been addressed
  • Compare against previous review findings
  • Post updated review

What Was Fixed

Critical Bug ✅ — _slice_belief_states now correctly slices component_states along axis 1 alongside component_beliefs, and handles nested tuples (FactoredState) via the inner isinstance(cs, tuple) branch in generator.py:107-114.

Missing Validation ✅ — build_nonergodic_partial_overlap now validates equal vocab sizes at builder.py:873-877 before computing n_shared/n_unique.

Inline Comments ✅ — The module-level comment block is gone; the implementation rationale lives in docstrings on _unflatten_state and generate().

Double Build ✅ — build_nonergodic_disjoint_vocab now uses _build_components_from_spec (a new shared helper) so components are built once at builder.py:789.

Missing Tests ✅ — Three new test classes address all gaps:

  • TestGenerateReturnAllStates — verifies shapes and belief validity for return_all_states=True
  • TestFactoredComponent — exercises the flatten/pad/unflatten path with heterogeneous state types
  • TestGenerateDataBatchWithFullHistory — catches _slice_belief_states misalignment, including with BOS token

Type annotation ✅ — component_states is now tuple[ComponentState, ...] where ComponentState = jax.Array | tuple[jax.Array, ...], which is much more precise than tuple[Any, ...].


Remaining Items

1. YAML config: spurious required fields

tests/end_to_end/configs/generative_process/nonergodic_example.yaml:7-8

base_vocab_size: ???  # Will auto-infer as 3
vocab_size: ???

These top-level ??? fields exist outside the instance: block, so they are never passed to build_nonergodic_process_from_spec. However, Hydra treats ??? as "mandatory" — any attempt to load this config without explicitly overriding these keys will throw MissingMandatoryValue. The comment # Will auto-infer as 3 is also incorrect (Hydra cannot infer them). These should either be given concrete values, set to null, or removed entirely. Fix this →

2. build_inflated_process passthrough (minor, judgment call)

builder.py:900-913 is still a one-liner wrapping InflatedVocabularyProcess(base_process, inflation_factor) with no additional logic. Per CLAUDE.md: "Don't create helpers, utilities, or abstractions for one-time operations." build_inflated_process_from_spec is the useful function here (it parses a spec dict); build_inflated_process adds nothing over calling the constructor directly. Removing it keeps the API surface minimal. Either way this is minor and doesn't block the PR.


Overall

The PR is in good shape. The critical correctness fix and all the missing tests are now present. The YAML issue is the only thing I'd call out before merging.
| Branch: kyle/nonergodic

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@kylejray
Copy link
Collaborator Author

@casperlchristensen

I think we are in good shape now, as far as the tests are concerned.

Re: claude's comments:

  • I think the ??? in the YAML is fine and consistent with what we do elsewhere.
  • I think the inflated_process "passthrough" is there to preserve the interface pattern of the other process builders. It's a trivial helper, but it may make sense here. Not sure on that one.

Base automatically changed from dev to main March 3, 2026 20:23
@casperlchristensen
Copy link
Collaborator

@claude Please re-check post fixing merge conflicts!

@claude
Copy link

claude bot commented Mar 3, 2026

Claude finished @casperlchristensen's task in 4m 16s —— View job


Post-Merge Re-check

Todo List:

  • Gather context from previous reviews
  • Inspect current diff vs main
  • Read key changed files
  • Check merge conflict resolution for correctness
  • Post updated review

CI status: ✅ Both workflows passing (End-to-End Tests + main workflow) on the post-merge commits.


Merge Conflict Resolution

Three files had merge conflicts. All are resolved correctly:

simplexity/generative_processes/structures/fully_conditional.py
Main merged in a new indexing.py module with build_other_factor_multipliers and flatten_index. The conflict resolution correctly replaced the old jnp.sum(tokens * mult) call (with an inline comment) with flatten_index(tokens, mult) from the new shared module — equivalent logic, better factored. ✅

simplexity/utils/factoring_utils.py
Main added _radix_multipliers, compute_other_multipliers, and compute_prefix_multipliers to this file. The PR added docstrings to compute_other_multipliers and compute_prefix_multipliers. Conflict resolved by correctly merging both sets of changes. ✅

simplexity/generative_processes/structures/conditional_transitions.py
PR's docstring additions cleanly merged with main's changes. ✅


Previously Flagged Items — Status

All critical and major items from the two prior reviews remain resolved. The two minor open items from the second review:

  • YAML ??? fields — The misleading # Will auto-infer as 3 comment was removed in commit 2520382. The ??? markers remain (PR author's call to keep them consistent with existing configs). ✅ addressed.
  • build_inflated_process passthrough — Still present; PR author keeps it for API consistency. Minor, judgment call. No regression.

New Observation: assert seed is not None (minor, correct)

builder.py:887 has assert seed is not None inside the elif mode == "random": branch, after the earlier guard:

if mode == "random" and seed is None:
    raise ValueError("seed is required when mode='random'")

The assert is retained to narrow seed: int | None for pyright, since the combined condition in the earlier guard doesn't allow pyright to infer seed is not None at the elif branch. This is idiomatic and correct — no issue.


Scope Note

The PR title says "Add NonErgodicGenerativeProcess and InflatedVocabularyProcess" but the diff includes a large new visualization system that wasn't mentioned in the title or description:

  • simplexity/visualization/ — 6 new files (~2,500 lines including a 1,338-line plotly_renderer.py)
  • simplexity/activations/visualization/ — 8 new files (~2,100 lines)

The previous reviews did not cover this code. The associated tests pass in CI. If the visualization system is intended to be part of this PR, it may warrant a dedicated description section or a separate PR for easier review.


Overall

The merge conflict resolution is clean — no correctness regressions introduced. All previously flagged critical issues remain fixed. The PR is in the same good state as after the second review. The only new consideration is whether the visualization code should be reviewed/scoped separately, which is a process question rather than a correctness concern.

@casperlchristensen
Copy link
Collaborator

@claude The visualization was not meant to be added -- it has been re-deleted!

@claude
Copy link

claude bot commented Mar 4, 2026

Claude finished @casperlchristensen's task in 33s —— View job


Post-Merge Re-check (visualization removed)

Todo List:

  • Gather context from previous reviews
  • Check current diff vs main to confirm visualization is removed
  • Review remaining changed files for any new issues
  • Post updated review

Visualization removed ✅ — The re-delete visualization commit (e67ab10) correctly removes all visualization files. The current diff vs origin/main contains exactly the 16 files listed in the PR description — no more, no less.

CI status: Both workflows are currently in progress (22648642063, 22648642054). Will need to pass before merging.


Current State

All issues from the previous three reviews remain resolved. There are no new findings introduced by the visualization deletion. The PR is scoped to exactly what the description says:

  • NonErgodicGenerativeProcess + InflatedVocabularyProcess + builder functions
  • Associated tests (21 unit tests for NonErgodic, tests for InflatedVocabulary, builder tests)
  • Merge-conflict resolution in fully_conditional.py, factoring_utils.py, conditional_transitions.py (all clean)
  • run_parallel.py and pyproject.toml changes

The two minor items from the second review (??? YAML fields, build_inflated_process passthrough) were triaged as intentional by the PR author — no action needed.

The PR looks good to merge once CI passes.



@casperlchristensen
Copy link
Collaborator

@ealt I'm happy with this now!

@kylejray kylejray enabled auto-merge March 13, 2026 01:40
ealt and others added 4 commits March 16, 2026 13:59
…e error checking. Update _build_prefix_vocab_maps for clarity and efficiency. Add tests for mismatched vocab maps and duplicate entries.
…p and introduce new utility functions for state management in NonErgodicGenerativeProcess
@kylejray kylejray added this pull request to the merge queue Mar 17, 2026
@ealt ealt removed this pull request from the merge queue due to a manual request Mar 17, 2026
@ealt ealt merged commit 4075750 into main Mar 17, 2026
4 checks passed
@ealt ealt deleted the kyle/nonergodic branch March 17, 2026 00:16