Lucius Bushnaq (Goodfire -- Work primarily carried out while at Apollo Research), Dan Braun (Goodfire -- Work primarily carried out while at Apollo Research), Lee Sharkey (Goodfire)
Correspondence: lucius@goodfire.ai, dan.braun@goodfire.ai, lee@goodfire.ai
A key step in reverse engineering neural networks is to decompose them into simpler parts that can be studied in relative isolation. Linear parameter decomposition—a framework that has been proposed to resolve several issues with current decomposition methods—decomposes neural network parameters into a sum of sparsely used vectors in parameter space. However, the current main method in this framework, Attribution-based Parameter Decomposition (APD), is impractical on account of its computational cost and sensitivity to hyperparameters. In this work, we introduce Stochastic Parameter Decomposition (SPD), a method that is more scalable and robust to hyperparameters than APD, which we demonstrate by decomposing models that are slightly larger and more complex than was possible to decompose with APD. We also show that SPD avoids other issues, such as shrinkage of the learned parameters, and better identifies ground-truth mechanisms in toy models. By bridging causal mediation analysis and network decomposition methods, and by removing barriers to scaling linear parameter decomposition methods to larger models, SPD opens up new research possibilities in mechanistic interpretability. We release a library for running SPD and reproducing our experiments at https://github.com/goodfire-ai/spd.
We have little understanding of the internal mechanisms that neural networks learn that enable their impressive capabilities. Understanding—or reverse engineering—these mechanisms may enable us to better predict and design neural network behavior and propensities for the purposes of safety and control. It may also be useful for scientific knowledge discovery: Neural networks can often perform better than humans on some tasks. They must therefore 'know' things about the world that we do not know—things that we could uncover by understanding their mechanisms.
An important first step to reverse engineering neural networks is to decompose them into individual mechanisms whose structure and interactions can be studied in relative isolation. Previous work has taken a variety of approaches to network decomposition. A popular approach is sparse dictionary learning (SDL) [Cunningham et al., 2023; Bricken et al., 2023], which aims to decompose neural network activations by optimizing sparsely activating dictionary elements to reconstruct or predict neural activation vectors. However, this approach suffers from a range of conceptual and practical problems, such as failing to account for feature geometry [Leask et al., 2025; Mendel et al., 2024] and not decomposing networks into functional components [Chanin et al., 2024; Bricken et al., 2023; Till et al., 2024] (see [Sharkey et al., 2025] for a review).
Recently, linear parameter decomposition [Braun et al., 2025] has been proposed to address some of the issues faced by SDL and other current approaches. Instead of decomposing networks into directions in activation space, linear parameter decomposition methods decompose networks into vectors in parameter space, called parameter components. Parameter components are selected such that, simultaneously, (a) they sum to the parameters of the original model, (b) as few as possible are required to replicate the network's behavior on any given input, and (c) they are as 'simple' as possible. This approach promises a framework that suggests solutions to issues like 'feature splitting' [Chanin et al., 2024; Bricken et al., 2023]; the foundational conceptual issue of defining a 'feature' (by re-basing it in the language of 'mechanisms'); and the issues of multidimensional features and feature geometry [Braun et al., 2025]. It also suggests a new way to bridge mechanistic interpretability and causal mediation analysis [Mueller et al., 2024].
However, Attribution-based Parameter Decomposition (APD) [Braun et al., 2025], the only method that has been so far proposed for linear parameter decomposition (with which this paper assumes some familiarity), suffers from several significant issues that hinder its use in practice, including:
- Scalability: APD has a high memory cost, since it decomposes a network into many parameter components, each of which is a whole vector in parameter space. They therefore each have the same memory cost as the original network.
- Sensitivity to hyperparameters: In the toy models it was tested on, APD only recovers ground-truth mechanisms for a very narrow range of hyperparameters. In particular, APD requires choosing the top-$k$ hyperparameter, the expected number of active parameter components per datapoint, which would usually not be known in advance for non-toy models. As discussed in [Braun et al., 2025], choosing a value for top-$k$ that is too high or low makes it difficult for APD to identify optimal parameter components.
- Use of attribution methods: APD relies on attribution methods (e.g. gradient attributions, used in [Braun et al., 2025]), to estimate the causal importance of each parameter component for computing the model's outputs on each datapoint. Gradient-based attributions, and attribution methods more generally, are often poor approximations of ground-truth causal importance [Syed et al., 2023] and sometimes fail to pass basic sanity checks [Adebayo et al., 2020].
In this work, we introduce a new method for linear parameter decomposition that overcomes all of these issues: Stochastic Parameter Decomposition (SPD).
Our approach decomposes each matrix in a network into a set of rank-one matrices called subcomponents. The number of rank-one matrices can be higher than the rank of the decomposed matrix. Subcomponents are not full parameter components as in the APD method, but they can later be aggregated into full components. In this work, we use toy models with known ground-truth mechanisms, so these groupings of subcomponents are straightforward to identify. In future work, however, it will be necessary to cluster subcomponents into components algorithmically in cases where ground truth is not known.
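As an illustrative sketch of this parametrization (plain NumPy, not the released `spd` library; all variable names here are hypothetical), each weight matrix of shape $d_{\text{out}} \times d_{\text{in}}$ can be held as $C$ rank-one outer products that are summed, optionally after masking, to form the effective weights:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, C = 4, 3, 8  # C subcomponents; C may exceed rank(W) = min(d_in, d_out)

# Subcomponent c is the rank-one matrix outer(U[c], V[c]).
U = rng.normal(size=(C, d_out))
V = rng.normal(size=(C, d_in))

def subcomponent_sum(U, V, mask=None):
    """Sum of rank-one subcomponents, each scaled by a mask value in [0, 1]."""
    if mask is None:
        mask = np.ones(len(U))
    return np.einsum("c,co,ci->oi", mask, U, V)

W = subcomponent_sum(U, V)                       # full reconstruction
W_ablated = subcomponent_sum(U, V, np.zeros(C))  # every subcomponent ablated
```

Setting individual mask entries to zero ablates individual subcomponents, which is the basic intervention the rest of the method is built on.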
Instead of relying on attribution techniques and a top-$k$ hyperparameter that needs to be chosen in advance, we define the causal importance of a subcomponent as how ablatable it is on a given datapoint. Causally important subcomponents should not be ablatable, and ablatable subcomponents should be causally unimportant for computing the output. We train a causal importance function to predict the causal importance of each subcomponent on each datapoint.
Crucially, we regularize the predicted causal importance values to be as small as possible, so that as few subcomponents as possible are predicted to be causally important on any given datapoint.
We apply SPD to all of the toy models that [Braun et al., 2025] used to study APD, including: a Toy Model of Superposition [Elhage et al., 2022]; a Toy Model of Compressed Computation [Braun et al., 2025]; and a Toy Model of Cross-Layer Distributed Representations. We demonstrate that the method recovers ground-truth mechanisms in all of these models. We also extend the suite of models to include two more challenging models where APD struggles but SPD succeeds: a Toy Model of Superposition with an additional identity matrix in the hidden space, and a deeper Toy Model of Cross-Layer Distributed Representations. These new models were unmanageably difficult to decompose correctly with APD, but SPD succeeds with relative ease.
The successful application of SPD to more challenging models demonstrates that SPD is more scalable and stable than APD. Nevertheless, some challenges remain. First, the method needs to be scaled to larger models, which will likely require further improvements in training stability. Second, SPD only finds rank-one components in individual layers, meaning that a further clustering step is required to find components that span more than one rank and/or more than one layer. In the toy models presented in this paper, these clusters are known and are therefore straightforward to identify. However, a general clustering solution will be needed to find such components where ground truth is unknown. Despite these challenges, SPD opens up new research avenues for mechanistic interpretability by introducing a linear parameter decomposition method that removes the main barriers to scaling to larger, non-toy models such as language models.
Suppose we have a trained neural network
- Faithfulness: The parameter components should sum to the parameters of the original network.
- Minimality: As few parameter components as possible should be used by the network for a forward pass of any given datapoint in the training dataset.
- Simplicity: Parameter components should use as little computational machinery as possible, in that they should span as few matrices and as few ranks as possible.
If a set of parameter components exhibit these three properties, we say that they comprise the network's mechanisms. In APD, gradient-based attributions are used to estimate the importance of each parameter component for a given datapoint. Then, the top-$k$ most important parameter components are summed together and used for a second forward pass. These active parameter components are trained to produce the same output on that datapoint as the target model. Simultaneously, the parameter components are trained to sum to the parameters of the target model, and are trained to be simple by penalizing the sum of the spectral $p$-norms of their weight matrices.
In our work, we aim to identify parameter components with the same three properties, but we achieve it in a different way.
A major issue with the APD method is that it is computationally very expensive: It involves optimizing
Here,
Note that the number of subcomponents in each layer
The way we optimize for faithfulness is the same as in [Braun et al., 2025], by optimizing the sum of our subcomponents to approximate the parameters of the target model:
where
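For a single decomposed matrix, the faithfulness objective can be sketched as a mean squared error between the summed subcomponents and the target weights (an illustrative NumPy sketch; the exact normalization is an assumption):

```python
import numpy as np

def faithfulness_loss(U, V, W_target):
    """MSE between the summed rank-one subcomponents and the target weights.

    U: (C, d_out) output-side vectors, V: (C, d_in) input-side vectors,
    W_target: (d_out, d_in) weights of the original (target) model.
    """
    W_hat = np.einsum("co,ci->oi", U, V)
    return np.mean((W_hat - W_target) ** 2)
```

In the full method, this loss is summed over every decomposed weight matrix in the network.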
Optimizing for minimality and simplicity by learning a causal importance function to stochastically sample masks
The way we optimize for minimality and simplicity is different from [Braun et al., 2025]. Since we already start with rank-one subcomponents that are localized in single layers, we don't need to optimize for subcomponent simplicity. Instead, we only need to train our set of subcomponents such that as few as possible are "active" or "used" or "required" by the network to compute its output on any given datapoint in the training set. We consider this equivalent to requiring that as few subcomponents as possible be causally important for computing the network's output.
To optimize our set of subcomponents such that as few as possible are causally important for computing the model's output on any given datapoint, we have three requirements:
- A formal definition of what it means for a subcomponent to be 'causally important' for computing the model's outputs.
- A loss function that trains causally important subcomponents to compute the same function as the original network.
- A loss function that encourages as many subcomponents as possible to be causally unimportant on each datapoint.
Intuitively, we say a subcomponent is causally important on a particular datapoint
Formally, suppose
Now we need a differentiable loss function(s) in order to be able to train the masked model
A loss function that lets us optimize causally important subcomponents to approximate the same function as the original network
Unfortunately, calculating
Here,
However, this loss can be somewhat noisy because it involves a forward pass in which every subcomponent has been multiplied by a random mask. Therefore, in addition to this loss, we also use an auxiliary loss $\mathcal{L}_{\text{stochastic-recon-layerwise}}$, which is simply a layerwise version of $\mathcal{L}_{\text{stochastic-recon}}$ where only the parameters in a single layer at a time are replaced by stochastically masked subcomponents. The gradients are still calculated at the output of the model:
This should not substantially alter the global optimum of training, because the layerwise loss is equivalent to the full loss if the subcomponents sum to the original weights and if we sample
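A minimal sketch of the stochastic masking and reconstruction step for a single linear layer (illustrative NumPy; sampling masks uniformly between the predicted causal importance value and 1, and using an MSE output loss, are simplifying assumptions for this sketch rather than a definitive description of our implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_masks(g, rng):
    """One mask per subcomponent, sampled uniformly in [g_c, 1].

    A subcomponent with causal importance g_c = 1 is never ablated; one
    with g_c = 0 may be scaled by any factor in [0, 1]."""
    return g + (1.0 - g) * rng.uniform(size=g.shape)

def stochastic_recon_loss(x, y_target, U, V, g, rng, n_samples=4):
    """Average output MSE of the stochastically masked model over mask samples."""
    losses = []
    for _ in range(n_samples):
        m = sample_masks(g, rng)
        W_masked = np.einsum("c,co,ci->oi", m, U, V)  # masked weight matrix
        losses.append(np.mean((W_masked @ x - y_target) ** 2))
    return float(np.mean(losses))
```

When every causal importance value is 1, the masked model reduces exactly to the sum of all subcomponents, so the loss vanishes whenever the subcomponents faithfully reconstruct the target weights.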
We have not yet defined how we obtain subcomponents' causal importance values
In theory, we could use any arbitrary causal importance function
The output of the causal importance function for each subcomponent is therefore a single scalar number that should be in the range
While this form of causal importance function works well for the toy models in this paper, it is likely that the best causal importance functions for arbitrary models require more expressivity. For example, a subcomponent's causal importance function may take as input the inner activations of all subcomponents, rather than just its own. Such variants may be explored in future work.
The causal importance values
where
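For illustration, the penalty on the causal importance values can be sketched as an $L_p$ norm (the exact functional form and the value of $p$ here are assumptions for the sketch):

```python
import numpy as np

def importance_minimality_loss(g, p=1.0):
    """L_p penalty on predicted causal importance values g (each in [0, 1]).

    Pushes as many subcomponents as possible toward g_c = 0, i.e. fully
    ablatable on this datapoint; p is a hyperparameter of the sketch."""
    return float(np.sum(np.abs(g) ** p))
```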
Our full loss function consists of four losses:
Our training setup involves five hyperparameters (excluding optimizer hyperparameters such as learning rate): The coefficients
We apply SPD to decompose a set of toy models with known ground-truth mechanisms. Some of these models were previously studied by [Braun et al., 2025] to evaluate APD, while others are new. We study:
- A Toy Model of Superposition (TMS) [Elhage et al., 2022] previously studied by [Braun et al., 2025];
- A TMS model with an identity matrix inserted in the middle of the model (not previously studied);
- A Toy Model of Compressed Computation previously studied by [Braun et al., 2025];
- Two Toy Models of Cross-Layer Distributed Representations: One with two residual MLP blocks (previously studied by [Braun et al., 2025]) and one with three MLP blocks (not previously studied);
In all cases, we find that SPD seems to identify known ground-truth mechanisms up to a small error. For the models that were also decomposed with APD in [Braun et al., 2025], we find that the SPD decompositions have fewer errors despite requiring less hyperparameter tuning to find the ground-truth mechanisms.
Code to reproduce our experiments can be found at https://github.com/goodfire-ai/spd. Training details and hyperparameters can be found in the Appendix. Additional figures and training logs can be found in the WandB report.
Figure 1: Results of running SPD on TMS_{5-2}. Top row: Plots of (left to right) the columns of the weight matrix of the target model; the sum of the SPD parameter components; and individual parameter components. Although this run of SPD used 20 subcomponents, only 6 subcomponents are shown, ordered by the sum of the norms of each of the columns of their (rank-one) weight matrices. The first five have learned one direction each, each corresponding to one of the columns of the target model. The final column and the other 14 components (not shown) have a negligible norm because they are superfluous for replicating the behavior of the target model. Bottom row: Depiction of the corresponding parametrized networks.
| Model | MMCS | ML2R |
|---|---|---|
| TMS_{5-2} | 1.000 ± 0.000 | 0.993 ± 0.002 |
| TMS_{40-10} | 1.000 ± 0.000 | 1.010 ± 0.007 |
| TMS_{5-2+ID} | 1.000 ± 0.000 | 0.992 ± 0.010 |
| TMS_{40-10+ID} | 1.000 ± 0.000 | 1.031 ± 0.001 |
Table 1: Mean Max Cosine Similarity (MMCS) and Mean L2 Ratio (ML2R), with standard deviations (to 3 decimal places), between the learned parameter subcomponents found by SPD and the target model weights for the embedding matrix W in the TMS_{5-2}, TMS_{40-10}, TMS_{5-2+ID}, and TMS_{40-10+ID} models. These results indicate that the ground-truth mechanisms are recovered almost perfectly, with negligible shrinkage, in all models.
We decompose [Elhage et al., 2022]'s Toy Model of Superposition (TMS), which can be written as
Typically,
The ground truth mechanisms in this model should be a set of rank-$1$ matrices that are zero everywhere except in the
We apply SPD to TMS models with
We quantify how aligned the learned parameter components vectors are to the columns of
where
We also quantify how close their magnitudes are with the mean L2 Ratio (ML2R) between the Euclidean norm of the columns of
where
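These two metrics can be sketched as follows (illustrative NumPy; representing each learned subcomponent by a single column vector in `learned` is an assumption about how subcomponents are summarized):

```python
import numpy as np

def mmcs_ml2r(W_target, learned):
    """Mean Max Cosine Similarity and Mean L2 Ratio.

    W_target: (m, n) target weights, one ground-truth direction per column.
    learned:  (C, m) one learned direction per subcomponent.
    For each target column, find the learned direction with the highest
    cosine similarity, then compare the L2 norms of the matched pair."""
    eps = 1e-12
    sims, ratios = [], []
    for i in range(W_target.shape[1]):
        w = W_target[:, i]
        cos = learned @ w / (np.linalg.norm(learned, axis=1) * np.linalg.norm(w) + eps)
        j = int(np.argmax(cos))
        sims.append(cos[j])
        ratios.append(np.linalg.norm(learned[j]) / (np.linalg.norm(w) + eps))
    return float(np.mean(sims)), float(np.mean(ratios))
```

An MMCS of 1 means every target direction has a perfectly aligned subcomponent; an ML2R below 1 would indicate shrinkage of the learned components.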
For both $\text{TMS}_{5-2}$ and $\text{TMS}_{40-10}$, the MMCS and ML2R values are close to $1$ (Table 1), indicating that the learned subcomponents closely match the columns of the target weight matrix.
Toy Model of Superposition with hidden identity
SDL methods are known to suffer from the phenomenon of 'feature splitting', where the features that are learned depend on the dictionary size, with larger dictionaries finding more sparsely activating, finer-grained features than smaller dictionaries [Bricken et al., 2023; Chanin et al., 2024]: Suppose a network has a hidden layer that simply implements a linear map. When decomposing this layer with a transcoder, we can continually increase the number of latents and learn ever more sparsely activating, ever more fine-grained latents to better minimize its reconstruction and sparsity losses. This problem is particularly salient in the case of a linear map, but similar arguments apply to nonlinear maps.
By contrast, [Braun et al., 2025] claimed that linear parameter decomposition methods do not suffer from feature splitting. In the linear case, SPD losses would be minimized by learning a single
Here, we empirically demonstrate this claim in a simple setting. We train toy models of superposition identical to $\text{TMS}_{5-2}$ and $\text{TMS}_{40-10}$, but with identity matrices inserted between the down-projection and up-projection steps of the models. These models, denoted $\text{TMS}_{5-2+\text{ID}}$ and $\text{TMS}_{40-10+\text{ID}}$, can be written as
We should expect SPD to find
SPD Results: Toy Model of Superposition with hidden identity
As expected, we find that SPD decomposes the embedding matrices
Also as expected, SPD decomposes the identity matrix
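This basis freedom is easy to verify directly: any orthonormal set of directions yields a valid rank-one decomposition of the identity, so the subcomponents only need to coactivate and sum to $I$. A quick NumPy check:

```python
import numpy as np

rng = np.random.default_rng(0)
m = 2  # hidden dimension of TMS_{5-2+ID}

# Any orthonormal basis Q of R^m gives I = sum_c q_c q_c^T.
Q, _ = np.linalg.qr(rng.normal(size=(m, m)))
subcomponents = [np.outer(Q[:, c], Q[:, c]) for c in range(m)]

assert np.allclose(sum(subcomponents), np.eye(m))
```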
Figure 2: Plots of (left to right) the columns of the input weight matrix W of the TMS_{5-2+ID} model (which can be written as $\hat{x}= \text{ReLU}(W^\top I W x + b)$, with a weight matrix $W \in \mathbb{R}^{m_1 \times m_2}$ and an identity matrix $I\in \mathbb{R}^{m_1 \times m_1}$); the sum of the parameter subcomponents for that matrix found by SPD; and the individual parameter subcomponents. Although this run of SPD used 20 subcomponents, only 6 subcomponents are shown, ordered by the sum of their matrix norms. The first five have learned one direction each, each corresponding to one of the columns of the target model. The final column and the other 14 components (not shown) have a negligible norm because they are superfluous for replicating the behavior of the target model.
Figure 3: Plots of (left to right) the weights in the hidden identity matrix I of the TMS_{5-2+ID}; the sum of all subcomponents found by SPD for that matrix (including small-norm subcomponents that are not shown); and the largest three individual subcomponents. We see that SPD finds two subcomponents that together sum to the original rank-2 identity matrix of the target model, while the other subcomponents have a negligible weight norm.
Figure 4: Plots of (left to right) the TMS_{5-2+ID} networks parametrized by: the target model parameters; the sum of all parameter subcomponents found in the SPD decomposition of the model; and the seven individual subcomponents of non-negligible size. We see that SPD finds five subcomponents for the embedding matrix W, corresponding to the five input features, and two subcomponents that span the identity matrix I in the middle of the model.
In this setting, the target network is a residual MLP that was previously studied by [Braun et al., 2025]. It consists of a single residual MLP layer of width
Figure 5: The architecture of the Toy Model of Compressed Computation. It uses a 1-layer residual MLP. Figure adapted from [Braun et al., 2025].
A naive solution to this task would be to dedicate each of the
Originally, we were unsure what the ground-truth mechanisms in the model's MLP output weight matrix
To understand how each neuron participates in computing the output for a given input feature, [Braun et al., 2025] measured the neuron's contribution to each input feature computation. The neuron contributions for input features
where ${W_E}_{[:,i]}$ and ${W_U}_{[i,:]}$ are the $i$-th column of the embedding matrix and the $i$-th row of the unembedding matrix, respectively.
Neuron contributions for individual subcomponents of
We apply SPD to the MLP weight matrices
Figure 6: Toy Model of Compressed Computation: Similarity between target model weights and SPD subcomponents for the first 10 (out of 100) input feature dimensions. Top: Neuron contributions measured by the model equation for each input feature index $i\in {0,\dots,9}$. Bottom: Neuron contributions for the corresponding parameter subcomponents, measured by the subcomponent equation for each input feature index $i\in {0,\dots,9}$. The neurons are numbered from 0 to 49 based on their raw position in the MLP layer.
Figure 7: Toy Model of Compressed Computation: Similarity between target model weights and SPD subcomponents for all 100 input feature dimensions. X-axis: Neuron contributions measured by the model equation for each input feature index $i\in {0,\dots,99}$. Y-axis: Neuron contributions for the corresponding parameter subcomponents of $W_{\text{in}}$, measured by the subcomponent equation for each feature index $i\in {0,\dots,99}$. There is a very close match between the X and Y axis for each neuron contribution, indicating that each subcomponent connects its corresponding feature to the MLP neurons with almost the same weights as the target model.
Meanwhile, SPD splits its corresponding
When the importance loss coefficient
Figure 8: Causal importance values of each subcomponent (clipped between 0 and 1) in response to one-hot inputs ($x_i=0.75$) for multiple importance loss coefficients $\beta_3$. The columns of each matrix are permuted differently, ordered by iteratively choosing the subcomponent (without replacement) with the highest causal importance for each input feature. When the importance loss coefficient $\beta_3$ is too low (left), subcomponents in $W_{\text{in}}$ are not 'monosemantic' (i.e. multiple subcomponents have causal importance for the same feature). When the importance loss coefficient $\beta_3$ is just right (middle), each subcomponent in $W_{\text{in}}$ has causal importance for computing a unique input feature. It also identifies the correct number of subcomponents in $W_{\text{out}}$ (50). Although it identifies the correct number of components, these subcomponents need not align with any particular basis, and hence look 'noisy' because they align with multiple features. But this does not matter, since they always co-activate together and sum to the target model's identity-matrix parameters. When the importance loss coefficient $\beta_3$ is too high (right), the rank-50 $W_{\text{out}}$ component is split into too many subcomponents: approximately one subcomponent for each feature, but where many of the subcomponents have small causal importance values for other features. Also, the causal importance values on the diagonal shrink far below 1.0, resulting in high $\mathcal{L}_{\text{stochastic-recon}}$ and $\mathcal{L}_{\text{stochastic-recon-layerwise}}$ losses.
Realistic neural networks seem capable of implementing mechanisms distributed across more than one layer [Yun et al., 2023; Lindsey & Olah, 2024]. To study the ability of SPD to identify components that are spread over multiple layers, [Braun et al., 2025] also studied a toy model trained on the same task with the same residual MLP architecture as the one in the previous section, but with the
As in the Toy Model of Compressed Computation, the model learns to compute individual functions using multiple neurons. But here it learns to do so using neurons that are spread over two layers. SPD should find subcomponents in both
We apply SPD on this model, as well as another model that spreads
Figure 9: The architecture of one of our two Toy models of Cross-Layer Distributed representations. The other toy model has three MLP blocks instead of two. Figure adapted from [Braun et al., 2025].
SPD finds qualitatively similar results to the
In the three-layer model, the MLP input matrices
In both the two- and three-layer models, we find that the computations occurring in each parameter subcomponent of the
Figure 10: Toy Model of Distributed Representations (Two Layers): Similarity between target model weights and SPD subcomponents for all 100 input feature dimensions in a 2-layer residual MLP. Each point represents one neuron's contribution to a particular input feature. X-axis: Neuron contributions measured by the model equation. Y-axis: Neuron contributions for the same neuron on the same input feature in the corresponding parameter subcomponents of $W^1_{\text{in}},W^2_{\text{in}}$, measured by the subcomponent equation. There is a close match between the X and Y axes for each neuron contribution, indicating that each subcomponent connects its corresponding feature to the MLP neurons with similar weights as the target model. However, there is a systematic skew toward higher values on the Y-axis, indicating that the neuron contributions of the subcomponents tend to be slightly larger. This is in contrast to the one-layer case and three-layer case. We currently do not understand the source of this discrepancy, but it is possibly an outcome of suboptimal hyperparameters.
Figure 11: Toy Model of Distributed Representations (Three Layers): Similarity between target model weights and SPD subcomponents for all 102 input feature dimensions in a 3-layer residual MLP. Each point represents one neuron's contribution for a particular input feature. X-axis: Neuron contributions measured by the model equation. Y-axis: Neuron contributions for the same neuron on the same input feature in the corresponding parameter subcomponents of $W^1_{\text{in}},W^2_{\text{in}}, W^3_{\text{in}}$, measured by the subcomponent equation. There is a close match between the X and Y axes for each neuron contribution, indicating that each subcomponent connects its corresponding feature to the MLP neurons with similar weights as the target model.
Figure 12: Toy Model of Distributed Representations (Two Layers): Causal importance values of each subcomponent (clipped between 0 and 1) for the matrices in both MLP layers, $W^1_{\text{in}},W^1_{\text{out}},W^2_{\text{in}},W^2_{\text{out}}$ in response to one-hot inputs ($x_i=0.75$). Each subcomponent in $W^1_{\text{in}},W^2_{\text{in}}$ has causal importance for computing a unique input feature. On the other hand, the combined 50 subcomponents of $W^1_{\text{out}},W^2_{\text{out}}$ all coactivate for all input features, indicating they are part of a single rank-50 identity component. The columns of each matrix are permuted differently, ordered by iteratively choosing the subcomponent (without replacement) with the highest causal importance for each input feature.
Figure 13: Toy Model of Distributed Representations (Three Layers): Causal importance values of each subcomponent (clipped between 0 and 1) for the matrices in all three MLP layers, $W^1_{\text{in}},W^1_{\text{out}},W^2_{\text{in}},W^2_{\text{out}},W^3_{\text{in}},W^3_{\text{out}}$ in response to one-hot inputs ($x_i=0.75$). Each subcomponent in $W^1_{\text{in}},W^2_{\text{in}},W^3_{\text{in}}$ has causal importance for computing a unique input feature. On the other hand, the combined 51 subcomponents of $W^1_{\text{out}},W^2_{\text{out}},W^3_{\text{out}}$ all coactivate for all input features, indicating they are part of a single rank-51 identity component. The columns of each matrix are permuted differently, ordered by iteratively choosing the subcomponent (without replacement) with the highest causal importance for each input feature.
In this paper we introduced SPD, a method that resolves many of the issues of APD [Braun et al., 2025]. The method is considerably more scalable and robust to hyperparameters, which we demonstrate by using the method to decompose deeper and more complex models than APD has successfully decomposed.
We hypothesize that this relative robustness comes from various sources:
- APD required estimating the expected number of active components in advance, because it needed to set the hyperparameter $k$ for selecting the top-$k$ most attributed components per batch. This number would usually not be known in advance for realistic models, and APD results were very sensitive to it. SPD uses trained causal importance functions instead, and therefore no longer needs a fixed estimate of the number of active subcomponents per datapoint. We still need to pick the loss coefficient for the causal importance penalty, $\beta_3$, but this is a much more forgiving hyperparameter than the hyper-sensitive top-$k$ hyperparameter.
- Gradients flow through every subcomponent on every datapoint, unlike in APD, where gradients only flowed through the top-$k$ most attributed components. Top-$k$ activation functions create discontinuities that, in general, tend to lead to unstable gradient-based training. In SPD, even subcomponents with causal importance values of zero will almost always permit gradients to flow, thus helping them error-correct if they are wrong.
- SPD does not need to optimize for 'simplicity' (in the sense of [Braun et al., 2025], where simple parameter components span as few ranks and layers as possible). Not only does this remove one hyperparameter (making tuning easier), but it also avoids inducing shrinkage in the singular values of the parameter component weight matrices. SPD does not exhibit shrinkage in the parameter subcomponents because the importance penalty targets the probability that a subcomponent will be unmasked, and does not directly penalize the norms of the singular values of the parameter matrices themselves. This can be helpful for learning correct solutions: For example, if correctly oriented parameter components exhibit shrinkage, then they will not sum to the parameters of the target model, and therefore other parameter components will need to compensate. For this reason, the faithfulness and simplicity losses in APD were in tension. Removing this tension, and removing a whole hyperparameter to tune, makes it easier to hit small targets in parameter space. Although there is shrinkage in the causal importance values, this does not appear to be very influential, since causal importance values only determine the minimum allowed value of the masks. For example, a causal importance value of $0.95$ indicates that we can ablate a subcomponent by up to five percent without significantly affecting the network output, but we can also simply not ablate it at all.
- APD used gradient-based attribution methods to estimate causal importance, even though those methods are only first-order approximations of ideal causal attributions, and often poor ones [Watson & Floridi, 2022; Kramár et al., 2024]. If causal attributions are wrong, then the wrong parameter components activate, and so the wrong parameter components are trained to compute the model's function on particular inputs. Systematic errors in causal attributions lead to systematic biases in the gradients, which can be catastrophic when trying to hit a very particular target in parameter space. SPD instead directly optimizes for causal importance, which likely yields much better estimates than the approximations produced by even somewhat sophisticated attribution methods (e.g. [Sundararajan et al., 2017]).
We hope that the stability and scalability of SPD will facilitate further scaling to much larger models than the ones studied here. In future work, we plan to test the scaling limits of the current method and explore any necessary adjustments for increased scalability and robustness.
We plan to investigate several outstanding issues in future work. One issue is that we are unsure whether learning subcomponents independently will enable the method to describe the network's function in as short a description as possible. It is possible that information from later layers is necessary to identify whether a given subcomponent is causally important. If so, then calculating causal importance values layerwise will mean that some subcomponents are active when they need not be. It may therefore be interesting to explore causal importance functions that take as input more global information from throughout the network, rather than only the subcomponent inner activations at a given layer.
Another issue is that both SPD and APD privilege mechanisms that live in individual layers, due to the importance loss in SPD and the simplicity loss in APD; it may be desirable to identify loss functions that privilege layers less.
The toy models in our work had known ground-truth mechanisms, so it was straightforward to identify which subcomponents should be grouped together into full parameter components. In the general case, however, we will not know this in advance. We therefore need to develop approaches that cluster subcomponents in a way that combines the sparse and dense coding schemes to achieve components that permit a minimum-length description of the network's function in terms of parameter components.
The SPD approach can be generalized in several straightforward ways. For instance, it is not necessary to decompose parameters strictly into subcomponents consisting of rank-one matrices: a subcomponent could instead consist of one rank in each matrix across all layers. Alternatively, different implementations of the causal importance function could be used.
We expect that SPD's scalability and stability will enable new research directions previously inaccessible with APD, such as investigating mechanisms of memorization and their relationship to neural network parameter storage capacity. Additionally, the principles behind SPD may be useful for training intrinsically decomposed models.
SPD uses randomly sampled masks in order to identify the minimal set of parameter subcomponents that are causally important for computing a model's output on a given input. This is, in essence, a method for finding a good set of causal mediators [Mueller et al., 2024; Vig et al., 2020; Geiger et al., 2021; Geiger et al., 2024]: Our causal importance function can be thought of as learning how to causally intervene on our target model in order to identify a minimal set of simple causal mediators of a model's computation. However, unlike many approaches in the causal intervention literature, our approach does not assume a particular basis for these interventions. Instead, it learns the basis in which causal interventions are made. It also makes these interventions in parameter space.
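As a concrete sketch of this sampling scheme, the masks are constructed as $m = g + (1-g)r$ from the predicted causal importances $g$ and uniform noise $r$ (as in our training algorithm; the specific values below are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

# Predicted causal importances for three subcomponents on one input
# (1.0 = always needed, 0.0 = fully ablatable).
g = np.array([1.0, 0.0, 0.3])

# Stochastic masks: m = g + (1 - g) * r with r ~ U(0, 1).
# A subcomponent with g=1 is never ablated; one with g=0 is scaled by a
# uniformly random factor in [0, 1); intermediate g interpolates between
# the two, so the ablation strength is bounded by the predicted importance.
r = rng.uniform(0.0, 1.0, size=g.shape)
m = g + (1.0 - g) * r

print(m)  # m[0] == 1.0 exactly; m[1] in [0, 1); m[2] in [0.3, 1.0)
```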
Some previous work learns masks to ablate parts of the model's input or parts of its parameters. But these masks tend to be fixed per input, and often assume a particular basis. Our work also differs in an important way from other masking-based attribution approaches: while our causal importances are predicted, the masks themselves are stochastically sampled, with an amount of randomness determined by the predicted causal importances.
Our definition of causal importance of subcomponents is related but not identical to the definition of causal dependence used in other literature [Lewis, 1973; Mueller et al., 2024]. For one, networks may compute particular outputs even when a causally important subcomponent is ablated, via emergent self-repair [McGrath et al., 2023]. In that case, a subcomponent may be causally important for a particular output, but the output may not be causally dependent on it.
SPD also has parallels to work that uses attribution methods to approximate the causal importance of model components or inputs (as in saliency maps) [Vig et al., 2020]. Our approach learns to predict causal importances given an input, which can be thought of as learning to predict attributions. To the best of our knowledge, no prior approach learns to predict the input-dependent ablatability of model components in the manner described in this paper.
[Chrisman et al., 2025] decomposed networks in parameter space by finding low-rank parameter components that can reconstruct, using sparse coefficients, the gradient of a loss between the network output and a baseline output. Somewhat similarly, [Matena et al., 2025] decomposed models in parameter space via non-negative factorization of the models' per-sample Fisher information matrices. SPD also decomposes networks into low-rank components in parameter space based on how the network output responds to perturbations. But instead of relying on local approximations such as gradients to estimate the effect of a perturbation, it is trained by directly checking the effect that ablating components in various combinations has on the output.
SPD can be viewed as approximately quantifying the degeneracy in neural network weights over different subdistributions of the data. SPD was in part inspired by singular learning theory [Watanabe, 2009], which quantifies degeneracy in network weights present over the entire data distribution using the learning coefficient. [Wang et al., 2024] defined the data-refined learning coefficient, which measures degeneracy in neural network weights over a chosen subset of the distribution. In contrast, SPD works in an unsupervised manner, finding a single set of vectors to represent the neural network parameters, such that as many vectors as possible are degenerate on any given datapoint in the distribution. SPD also requires vectors to be ablatable to zero rather than just being degenerate locally.
See [Braun et al., 2025] for further discussion on the relationship between linear parameter decomposition methods and sparse autoencoders; transcoders; weight masking and pruning; circuit discovery and causal mediation analysis; interpretability of neural network parameters; mixture of experts; and loss landscape intrinsic dimensionality and degeneracy.
We thank Tom McGrath, Stefan Heimersheim, Daniel Filan, Bart Bussmann, Logan Smith, Nathan Hu, Dashiell Stander, Brianna Chrisman, Kola Ayonrinde, and Atticus Geiger for helpful feedback on previous drafts of this paper. We also thank Stefan Heimersheim for initially suggesting the idea to train on an early version of what became the stochastic loss, which we had up to then only treated as a validation metric. Additionally, we thank Kaarel Hänni, whose inputs helped us develop the definition of component causal importance, and Linda Linsefors, whose corrections of earlier work on interference terms arising in computation in superposition helped us interpret the toy model of compressed computation.
The MLPs
Here,
Note that there is no sum over
Note that this choice of causal importance function is only one of many possibilities. We chose it for its relative simplicity and low computational cost. In theory, SPD should be compatible with any method of predicting causal importance values.
The flat regions in a hard sigmoid function can lead to dead gradients for inputs below $0$ or above $1$.
We use the lower leaky hard sigmoid for the forward pass because we should usually be able to scale a subcomponent that does not influence the output of the model below zero, but we cannot scale a subcomponent that does influence the output of the model above $1$.
We use the upper leaky hard sigmoid in the importance loss because the causal importance values cannot be allowed to take negative values, else the training would not be incentivized to sparsify them. But there is no issue with allowing causal importance function outputs to be greater than $1$.
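A minimal sketch of the two leaky hard sigmoid variants, assuming a base hard sigmoid of the form $\mathrm{clip}(x, 0, 1)$ and a hypothetical leak slope `ALPHA` (the paper's exact parameterization may differ):

```python
import numpy as np

ALPHA = 0.01  # leak slope; hypothetical value, not taken from the paper

def hard_sigmoid_lower_leaky(x, alpha=ALPHA):
    """Hard sigmoid that is leaky below 0 and clamped at 1 above.

    Suitable for the forward pass: gradients survive for x < 0, but the
    output never exceeds 1, since scaling an influential subcomponent
    above 1 would change the model output.
    """
    return np.where(x < 0, alpha * x, np.minimum(x, 1.0))

def hard_sigmoid_upper_leaky(x, alpha=ALPHA):
    """Hard sigmoid that is clamped at 0 below and leaky above 1.

    Suitable for the importance loss: outputs cannot go negative (which
    would let the loss reward negative importances), but values above 1
    still receive gradient.
    """
    return np.where(x > 1, 1.0 + alpha * (x - 1.0), np.maximum(x, 0.0))
```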
Here we list some heuristics for how to select hyperparameters to balance the different loss terms of SPD. In particular, we focus on how to determine the appropriate trade-off between the stochastic reconstruction losses $\mathcal{L}_{\text{stochastic-recon}}$ and $\mathcal{L}_{\text{stochastic-recon-layerwise}}$ (controlled by $\beta_1$ and $\beta_2$) and the importance minimality loss $\mathcal{L}_{\text{importance-minimality}}$ (controlled by $\beta_3$):
- Negligible performance loss: The performance difference between the SPD model and the original model on the training dataset should be small enough to be mostly negligible. Quantitatively, one might judge how large the performance drop is based on LM scaling curves, as [Gao et al., 2024] suggested for judging the quality of SAE reconstructions.
- Noise from superposition: Inactive mechanisms in superposition can still contribute to the model output through small interference terms [Hänni et al., 2024]. We can estimate the expected size of such terms and ensure that $\mathcal{L}_{\text{stochastic-recon}}$ and $\mathcal{L}_{\text{stochastic-recon-layerwise}}$ are no larger than what could plausibly be caused by such noise terms.
- Recovering known mechanisms: If some of the ground-truth mechanisms in the target model are already known, we can restrict hyperparameters such that they recover those mechanisms. For example, in a language model, the embedding matrix mechanisms are usually known: each vocabulary element should be assigned one mechanism. If none of a model's mechanisms are known to start with, we could insert known mechanisms into it. For example, one might insert an identity matrix at some layer in an LLM and check whether the SPD decomposition recovers it as a single high-rank component, as in the TMS-with-identity model.
- Other sanity checks: The decomposition should pass other sanity checks. For example, a model parametrized by the sum of unmasked subcomponents should recover the performance of the target model. And at least some of the causal importance values should take values of $1$ for most inputs; if they do not, then the importance minimality loss coefficient $\beta_3$ is likely too high.
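The "sum of unmasked subcomponents" check can be written down directly: with all masks set to $1$, each layer's reconstructed weight should match the target weight. A minimal sketch with hypothetical shapes, where the decomposition is faithful by construction so the check passes:

```python
import numpy as np

rng = np.random.default_rng(0)

d_in, d_out, C = 8, 4, 16  # hypothetical layer and subcomponent dimensions

# A perfectly faithful decomposition for illustration: build U, V first
# and define the target weight as their product.
U = rng.normal(size=(d_out, C))
V = rng.normal(size=(C, d_in))
W_target = U @ V

# With all masks set to 1 (no ablation), the reconstructed weight is
# W' = U @ Diag(1) @ V = U @ V, so it should match the target exactly.
W_unmasked = U @ np.diag(np.ones(C)) @ V

faithfulness_err = np.linalg.norm(W_target - W_unmasked) ** 2
print(f"faithfulness error: {faithfulness_err:.2e}")  # ~0 up to float error
```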
Target model training
All target models were trained for 10k steps using the AdamW optimizer [Loshchilov & Hutter, 2019] with weight decay 0.01 and a constant learning rate.

- $\text{TMS}_{5-2}$ and $\text{TMS}_{5-2+\text{ID}}$: batch size 1024
- $\text{TMS}_{40-10}$ and $\text{TMS}_{40-10+\text{ID}}$: batch size 8192
SPD training: Common hyperparameters
- Optimizer: Adam with max learning rate $1\times 10^{-3}$ and a cosine learning rate schedule
- Training: 40k steps, batch size 4096
- Data distribution: same as target model (feature probability 0.05)
- Stochastic sampling: $S=1$ for $\mathcal{L}_{\text{stochastic-recon}}$ and $\mathcal{L}_{\text{stochastic-recon-layerwise}}$
- Loss coefficients: $\mathcal{L}_{\text{faithfulness}}=1$, $\mathcal{L}_{\text{stochastic-recon}}=1$, $\mathcal{L}_{\text{stochastic-recon-layerwise}}=1$
- Causal importance functions: one MLP per subcomponent, each with one hidden layer of $d_{\text{gate}}=16$ $\text{GELU}$ neurons
SPD training: Model-specific hyperparameters
- $\text{TMS}_{5-2}$ and $\text{TMS}_{5-2+\text{ID}}$: $\mathcal{L}_{\text{importance-minimality}}$ coefficient $3\times10^{-3}$, $p=1$
- $\text{TMS}_{40-10}$ and $\text{TMS}_{40-10+\text{ID}}$: $\mathcal{L}_{\text{importance-minimality}}$ coefficient $1\times10^{-4}$, $p=2$
Model architectures
- 1-layer and 2-layer residual MLPs: 100 input features, embedding dimension 1000, 50 MLP neurons total (25 per layer for 2-layer)
- 3-layer residual MLP: 102 input features, embedding dimension 1000, 51 MLP neurons total (17 per layer)
Target model training
All models trained using AdamW [Loshchilov & Hutter, 2019] with weight decay 0.01, max learning rate
SPD training: Common hyperparameters
- Optimizer: Adam with a constant learning rate (model-specific values below)
- Batch size: 2048
- Data distribution: same as target model (feature probability 0.01)
- Stochastic sampling: $S=1$ for both stochastic losses
- Loss coefficients: $\mathcal{L}_{\text{stochastic-recon}}=1$, $\mathcal{L}_{\text{stochastic-recon-layerwise}}=1$
- $p=2$ for $\mathcal{L}_{\text{importance-minimality}}$
SPD training: Model-specific hyperparameters
- 1-layer residual MLP: learning rate $2\times10^{-3}$, $\mathcal{L}_{\text{importance-minimality}}$ coefficient $1\times10^{-5}$, $C=100$ initial subcomponents, 30k training steps, causal importance function with $d_{\text{gate}}=16$ hidden neurons per subcomponent
- 2-layer residual MLP: learning rate $1\times10^{-3}$, $\mathcal{L}_{\text{importance-minimality}}$ coefficient $1\times10^{-5}$, $C=400$ initial subcomponents, 50k training steps, causal importance function with $d_{\text{gate}}=16$ hidden $\text{GELU}$ neurons per subcomponent
- 3-layer residual MLP: learning rate $1\times10^{-3}$, $\mathcal{L}_{\text{importance-minimality}}$ coefficient $0.5\times10^{-5}$, $C=500$ initial subcomponents, 200k training steps (converges around 70k), causal importance function with $d_{\text{gate}}=128$ hidden $\text{GELU}$ neurons per subcomponent
Algorithm: Stochastic Parameter Decomposition (SPD)
Input:

- Target model $f(\cdot, W)$ with parameters $W = \{W^l\}_{l=1}^L$ to be decomposed
- Dataset $\mathcal{D}$
- $C$ subcomponents per layer
- Loss coefficients $\beta_1, \beta_2, \beta_3$
- Causal importance minimality loss p-norm $p > 0$
- Number of mask samples $S \ge 1$

Output:

- Learned subcomponents $\{U^l, V^l\}_{l=1}^L$ and parameters of the causal importance MLPs $\{\Gamma^l\}_{l=1}^L$

Algorithm:

- Initialize subcomponents $U^l \in \mathbb{R}^{d_{\text{out}} \times C}$, $V^l \in \mathbb{R}^{C \times d_{\text{in}}}$ for each layer $l$
- Initialize parameters of the causal importance MLPs $\Gamma^l$
- Initialize an optimizer for all trainable parameters $\{U^l, V^l, \Gamma^l\}_{l=1}^L$
- For each training step:
    - Sample a data batch $X = \{x_1, \dots, x_B\}$ from $\mathcal{D}$
    - Compute target model outputs $Y_{\text{target}}$ and pre-weight activations $\{a^l\}_{l=1}^L$ for batch $X$
    - $\mathcal{L}_{\text{faithfulness}} \gets \frac{1}{N} \sum_{l=1}^L \|W^l - U^l V^l\|_F^2$
    - For each layer $l$:
        - $h^l \gets V^l a^l$ // Inner activations
        - $G^l_{\text{raw}} \gets \Gamma^l(h^l)$ // Raw causal importance MLP outputs
    - $\mathcal{L}_{\text{importance-minimality}} \gets \frac{1}{B} \sum_{b=1}^B \sum_{l=1}^L \sum_{c=1}^C |\sigma_{H,\text{upper}}(G^l_{\text{raw},b,c})|^p$
    - $\mathcal{L}_{\text{stochastic-recon}} \gets 0$; $\mathcal{L}_{\text{stochastic-recon-layerwise}} \gets 0$
    - Let $G^l = \sigma_{H,\text{lower}}(G^l_{\text{raw}})$ for all $l$
    - For $s = 1, \dots, S$:
        - Sample $R_s^l \sim \mathcal{U}(0,1)^{B \times C}$ for each layer $l$
        - Compute masks $M_s^l \gets G^l + (1 - G^l) \odot R_s^l$
        - Construct masked weights $\{W'^{(s)}_b\}_{b=1}^B$ with $W'^{(s),l}_b = U^l \, \text{Diag}(M^l_{s,b}) \, V^l$
        - $Y_{\text{masked}} \gets f(X, W'^{(s)})$
        - $\mathcal{L}_{\text{stochastic-recon}} \gets \mathcal{L}_{\text{stochastic-recon}} + D(Y_{\text{masked}}, Y_{\text{target}})$
        - For each layer $l'$:
            - Construct layerwise masked weights $\{W'^{(s,l')}_b\}_{b=1}^B$ with $W'^{(s,l'),l}_b = U^l \, \text{Diag}(M^l_{s,b} \text{ if } l = l' \text{ else } \mathbf{1}) \, V^l$
            - $Y_{\text{masked-layerwise}} \gets f(X, W'^{(s,l')})$
            - $\mathcal{L}_{\text{stochastic-recon-layerwise}} \gets \mathcal{L}_{\text{stochastic-recon-layerwise}} + D(Y_{\text{masked-layerwise}}, Y_{\text{target}})$
    - Normalize: $\mathcal{L}_{\text{stochastic-recon}} \gets \mathcal{L}_{\text{stochastic-recon}} / S$
    - Normalize: $\mathcal{L}_{\text{stochastic-recon-layerwise}} \gets \mathcal{L}_{\text{stochastic-recon-layerwise}} / (S \cdot L)$
    - $\mathcal{L}_{\text{SPD}} \gets \mathcal{L}_{\text{faithfulness}} + \beta_1 \mathcal{L}_{\text{stochastic-recon}} + \beta_2 \mathcal{L}_{\text{stochastic-recon-layerwise}} + \beta_3 \mathcal{L}_{\text{importance-minimality}}$
    - Update parameters $\{U^l, V^l, \Gamma^l\}_{l=1}^L$ using gradients of $\mathcal{L}_{\text{SPD}}$
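The training step above can be condensed into a minimal runnable sketch. This is illustrative only: it covers a single linear layer, uses a stand-in importance function in place of the per-subcomponent MLPs $\Gamma$, computes only the forward losses (no gradient update), and omits the layerwise loss since there is only one layer; all shapes and values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

B, d_in, d_out, C, S = 32, 10, 6, 8, 2  # batch, dims, subcomponents, samples

# Target "model": a single linear layer y = x W^T, optionally with a
# per-datapoint weight when masked weights are used.
W = rng.normal(size=(d_out, d_in))
f = lambda X, Wl: X @ Wl.T if Wl.ndim == 2 else np.einsum('bi,boi->bo', X, Wl)

# Trainable decomposition (here just randomly initialized, not optimized).
U = rng.normal(size=(d_out, C)) * 0.1
V = rng.normal(size=(C, d_in)) * 0.1

X = rng.normal(size=(B, d_in))
Y_target = f(X, W)

# Faithfulness loss: ||W - U V||_F^2 normalized by the parameter count.
L_faith = np.sum((W - U @ V) ** 2) / W.size

# Causal importances from a stand-in "importance function" (a real run
# would use a small MLP per subcomponent on the inner acts h = V a).
h = X @ V.T                       # inner activations, shape (B, C)
G = np.clip(np.abs(h), 0.0, 1.0)  # stand-in for sigma(Gamma(h))

# Stochastic reconstruction loss averaged over S mask samples.
L_stoch = 0.0
for _ in range(S):
    R = rng.uniform(size=(B, C))
    M = G + (1.0 - G) * R                      # masks, shape (B, C)
    # Per-datapoint masked weights: W'_b = U Diag(M_b) V
    W_masked = np.einsum('oc,bc,ci->boi', U, M, V)
    Y_masked = f(X, W_masked)
    L_stoch += np.mean((Y_masked - Y_target) ** 2)  # D = MSE here
L_stoch /= S

# Importance minimality with p = 1.
L_imp = np.mean(np.sum(G, axis=-1))

L_total = L_faith + L_stoch + L_imp  # all beta coefficients set to 1
print(f"L_faith={L_faith:.3f}  L_stoch={L_stoch:.3f}  L_imp={L_imp:.3f}")
```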
Adebayo, J., Gilmer, J., Muelly, M., Goodfellow, I., Hardt, M., & Kim, B. (2020). Sanity checks for saliency maps. Advances in Neural Information Processing Systems, 33, 9505-9515.
Braun, D., Bushnaq, L., & Sharkey, L. (2025). Interpretability in parameter space: Minimizing mechanistic description length with attribution-based parameter decomposition. arXiv preprint arXiv:2501.14926.
Bricken, T., Templeton, A., Batson, J., Chen, B., Jermyn, A., Conerly, T., ... & Olah, C. (2023). Towards monosemanticity: Decomposing language models with dictionary learning. Transformer Circuits Thread.
Cao, Y., Chen, K., & Rudin, C. (2021). Low-complexity probing via finding subnetworks. arXiv preprint arXiv:2102.06979.
Chanin, D., Templeton, A., & Olah, C. (2024). Absorption: Studying feature splitting in sparse autoencoders. Transformer Circuits Thread.
Chrisman, A., et al. (2025). Identifying sparsely active circuits in neural networks. arXiv preprint.
Csordás, R., van Steenkiste, S., & Schmidhuber, J. (2021). Are neural nets modular? Inspecting functional modularity through differentiable weight masks. arXiv preprint arXiv:2103.13977.
Cunningham, H., Ewart, A., Riggs, L., Huben, R., & Sharkey, L. (2023). Sparse autoencoders find highly interpretable features in language models. arXiv preprint arXiv:2309.08600.
Elhage, N., Hume, T., Olsson, C., Schiefer, N., Henighan, T., Kravec, S., ... & Olah, C. (2022). Toy models of superposition. arXiv preprint arXiv:2209.10652.
Gao, L., et al. (2024). Scaling and evaluating sparse autoencoders. arXiv preprint.
Geiger, A., Lu, H., Icard, T., & Potts, C. (2021). Causal abstractions of neural networks. Advances in Neural Information Processing Systems, 34, 9578-9589.
Geiger, A., Wu, Z., Potts, C., Icard, T., & Goodman, N. (2024). Finding alignments between interpretable causal variables and distributed neural representations. arXiv preprint arXiv:2303.02536.
Hänni, K., Mendel, J., Vaintrob, D., & Chan, L. (2024). Mathematical models of computation in superposition. arXiv preprint arXiv:2408.05451.
Jermyn, A. (2024). Tanh activations in sparse autoencoders. Transformer Circuits Thread.
Kingma, D. P., & Welling, M. (2013). Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114.
Kramár, J., Lieberum, T., Shah, R., & Nanda, N. (2024). AtP*: An efficient and scalable method for localizing LLM behaviour to components. arXiv preprint arXiv:2403.00745.
Leask, C., Templeton, A., & Olah, C. (2025). Sparse autoencoders: Canonical units and feature splitting. Transformer Circuits Thread.
Lewis, D. (1973). Causation. The Journal of Philosophy, 70(17), 556-567.
Lindsey, J., & Olah, C. (2024). Cross-coders: Interpretable feature discovery in language models. Transformer Circuits Thread.
Loshchilov, I., & Hutter, F. (2019). Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101.
Matena, M., et al. (2025). Uncovering model processing strategies via parameter space analysis. arXiv preprint.
McGrath, T., Rahtz, M., Kramár, J., Mikulik, V., & Legg, S. (2023). The hydra effect: Emergent self-repair in language model computations. arXiv preprint arXiv:2307.15771.
Mendel, J. (2024). SAE: Sparse autoencoders for interpretability. Transformer Circuits Thread.
Mueller, A. (2024). Missed causes and ambiguous effects: Counterfactuals pose challenges for interpreting neural networks. arXiv preprint.
Mueller, A., et al. (2024). The quest for the right mediator: A history, survey, and theoretical grounding of causal interpretability. arXiv preprint arXiv:2408.01416.
Sharkey, L., Braun, D., & Millidge, B. (2022). Taking features out of superposition with sparse autoencoders. Alignment Forum.
Sundararajan, M., Taly, A., & Yan, Q. (2017). Axiomatic attribution for deep networks. International Conference on Machine Learning, 3319-3328.
Syed, A., Rager, C., & Conmy, A. (2023). Attribution patching outperforms automated circuit discovery. arXiv preprint arXiv:2310.10348.
Till, J., Templeton, A., & Olah, C. (2024). True features: A framework for understanding feature splitting in sparse autoencoders. Transformer Circuits Thread.
Vig, J., Gehrmann, S., Belinkov, Y., Qian, S., Nevo, D., Singer, Y., & Shieber, S. (2020). Causal mediation analysis for interpreting neural NLP: The case of gender bias. arXiv preprint arXiv:2004.12265.
Wang, G., Hoogland, J., van Wingerden, S., Furman, Z., & Murfet, D. (2024). Differentiation and specialization of attention heads via the refined local learning coefficient. arXiv preprint arXiv:2410.02984.
Watanabe, S. (2009). Algebraic geometry and statistical learning theory. Cambridge University Press.
Watson, D., & Floridi, L. (2022). The explanation game: A formal framework for interpretable machine learning. Synthese, 200(2), 1-32.
Wright, B., & Sharkey, L. (2024). Addressing feature suppression in SAEs. Alignment Forum.
Yun, Z., Chen, Y., Olshausen, B. A., & LeCun, Y. (2023). Transformer visualization via dictionary learning: Contextualized embedding as a linear superposition of transformer factors. arXiv preprint arXiv:2103.15949.
Zhang, C., Bengio, S., Hardt, M., Recht, B., & Vinyals, O. (2021). Understanding deep learning requires rethinking generalization. Communications of the ACM, 64(3), 107-115.
Zhang, D., Ahuja, K., Xu, Y., Wang, Y., & Courville, A. (2021). Can subnetwork structure be the key to out-of-distribution generalization? International Conference on Machine Learning.