We recently put out a new paper on a scalable generalization of influence functions, which quantify how training data affects model behavior (see Nina’s post). I’m excited about this because it takes a completely new methodological approach to measuring influence.
Instead of relying on a Hessian inverse (which is ill-defined and expensive), our new “Bayesian” influence functions (BIF) rely on a covariance calculation (which can be scalably estimated with MCMC). This approach is more theoretically sound (no more Hessian inverses), and it achieves what I think is a more desirable set of engineering tradeoffs (better model-size scaling but worse dataset-size scaling).
At Timaeus, we think these kinds of techniques are on the critical path to safety. Modern alignment techniques like RLHF and Constitutional AI are about controlling model behavior by selecting the right training data. If this continues to be the case, we will need better tools for understanding and steering the pipeline from data to behavior.
It’s still early days for the BIF. We’ve done some initial validation on retraining benchmarks and other quantitative tests (follow-up work coming soon), where the BIF comes out looking strong, but more work will be needed to understand the full set of costs and benefits. As that foundation gets established, we expect we’ll be able to start applying these techniques directly to safety-relevant problems.

You can read the full announcement thread on X (reproduced below):
How does training data shape model behavior? Well, it’s complicated…
But we can make progress by studying a simplified, linear version of the mapping from data to behavior. This is the idea behind influence functions (IFs), one of the main pillars of modern training data attribution.
Unfortunately, classical influence functions are fundamentally limited:
Theoretically, IFs assume a unique, isolated global minimum. This is never true for NNs.
Practically, the dependence on the Hessian poses a severe memory bottleneck that grows quadratically with model size.
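For reference, the classical influence function (in the standard form, following Koh & Liang) estimates the effect of upweighting a training point $z$ on a measurement $f$ as an inverse-Hessian-vector product at the optimum $\theta^*$:

$$\mathcal{I}(z, f) \;=\; -\nabla_\theta f(\theta^*)^\top \, H^{-1} \, \nabla_\theta \ell(z, \theta^*), \qquad H = \nabla^2_\theta L(\theta^*).$$

Both limitations above are visible here: $H^{-1}$ requires an isolated minimum to be well-defined, and $H$ has one entry per pair of parameters.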
Still, there’s hope. The Bayesian influence function (BIF) addresses both issues. The idea:
Study influence not on a single minimum but on the distribution of low-loss solutions.
Skip the Hessian inversion. Compute covariance over this distribution.
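Schematically (my gloss, with notation assumed rather than taken from the paper), the BIF of a training point $z_m$ on a measurement $f$ becomes a covariance under a distribution $p(w)$ concentrated on low-loss solutions:

$$\mathrm{BIF}(z_m, f) \;=\; \mathrm{Cov}_{w \sim p}\big(\ell(z_m, w),\, f(w)\big),$$

with no matrix inversion anywhere; the hard part moves into sampling from $p$.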
At first glance, this looks like a step backwards: computing a covariance over the full Bayesian posterior is much more intractable than computing the Hessian! And we typically care about influence for a specific checkpoint, not aggregated over all possible solutions.
In our new paper, we solve both problems by introducing:
A local version of the BIF that applies to individual NN checkpoints.
A scalable stochastic-gradient MCMC estimator.
The local BIF bypasses the Hessian bottleneck and is well-defined even for degenerate models. It can be batched and scales to billions of parameters. One of the best perks is that we get fine-grained per-token influence functions at no extra compute cost.
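To make the estimator concrete, here is a minimal sketch of how a covariance-based influence estimate could be computed from SGMCMC samples. All names (`model`, `loss_fn`, `sgld_step`) are hypothetical stand-ins, and the paper’s actual estimator may differ in details such as tempering, localization, and variance reduction:

```python
import torch

def bif_scores(model, train_points, query_point, loss_fn, sgld_step, n_samples=100):
    """Estimate Bayesian influence of each training point on a query loss.

    Influence is the empirical covariance, across posterior samples drawn
    with a user-supplied SGMCMC sampler (`sgld_step`), between per-example
    training losses and the query loss.
    """
    train_losses, query_losses = [], []
    for _ in range(n_samples):
        sgld_step(model)  # one (or a few) SGLD updates: w <- w - lr*grad + noise
        with torch.no_grad():
            train_losses.append(torch.stack([loss_fn(model, z) for z in train_points]))
            query_losses.append(loss_fn(model, query_point))
    L = torch.stack(train_losses)        # (n_samples, n_train)
    q = torch.stack(query_losses)        # (n_samples,)
    L_c = L - L.mean(0, keepdim=True)    # center across posterior samples
    q_c = q - q.mean()
    return (L_c * q_c[:, None]).mean(0)  # empirical covariance per train point
```

Per-token influence falls out of the same loop if `loss_fn` returns per-token losses, since the covariance can be taken coordinate-wise over the same samples.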
To validate the BIF, we test it on a standard retraining benchmark, via the Linear Datamodeling Score (LDS). We find that it is competitive with leading IF-based approximations, especially in the small-dataset regime.
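For readers unfamiliar with the LDS: it scores an attribution method by how well an additive model over the attribution scores predicts the outcome of actually retraining on random data subsets (see Park et al.’s TRAK paper for the canonical definition). A minimal sketch, with my own hypothetical array layout:

```python
import numpy as np
from scipy.stats import spearmanr

def linear_datamodeling_score(influence, subsets, retrained_outputs):
    """Linear Datamodeling Score for one query example.

    influence:          (n_train,) attribution scores for the query
    subsets:            (n_subsets, n_train) 0/1 masks of retraining subsets
    retrained_outputs:  (n_subsets,) query output of the model retrained on each subset
    """
    predicted = subsets @ influence  # additive (linear) prediction per subset
    return spearmanr(predicted, retrained_outputs).correlation
```

The reported LDS is typically this rank correlation averaged over many query examples.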
There are caveats: the BIF exhibits worse scaling with dataset size, we’re still in the early days of understanding the role of SGMCMC hyperparameters, and generally more investigation is needed!
But we see straightforward ways to make progress on these problems.
I’m pretty excited about building tools/methods for better understanding of dataset influence, so this intuitively seems pretty exciting! (I’m interested both in better cheap approximation of the effects of leaving some data out and of the effects of adding some data in.)
(I haven’t looked at the exact method and results in this paper yet.)
Great to see more work on (better) influence functions!
Lots of interesting things to discuss here[1], but one thing I would like to highlight is that classical IFs indeed arise when you do the usual implicit function theorem + global minimum assumption (which is obviously violated in the context of DL), but they also arise as the limit of unrolling as $t \to \infty$. What follows is more theoretical in nature, summarizing statements in Mlodozeniec et al.
Influence functions suffer from another shortcoming: they only use the final weights (as you are aware). So you might say that we shouldn’t do influence functions, but instead track a different counterfactual over training: “What if I added/removed a sample $z_m$ at time step $t$?” To do this, you can consider each SGD training step $\theta_t \to \theta_{t+1}$ (or more generally a step of some optimizer like Adam) and approximate its Jacobian, i.e. linearize the perturbed trajectory as $\theta'_{t+1} \approx \theta_{t+1} + A_t \cdot (\theta'_t - \theta_t)$. Doing some calculus, you end up with $A_t = I - \lambda_t H_t$, where $\lambda_t$ is the learning rate and $H_t$ is the mini-batch Hessian at time step $t$.[2]
You can use this linear approximation of the training steps to compute a new counterfactual (Eq. 57 in Mlodozeniec et al.). This can be formalized as a pair $(\theta_t, r_t)$ of the weights $\theta_t$ and the response $r_t$, which captures the counterfactual: $\theta'_t(\epsilon) \approx \theta_t + \epsilon \cdot r_t$, where $\theta'_t(\epsilon)$ is the counterfactual of adding the data point with weighting $\epsilon$ at time step $t$. Ok, without further ado, here is the result (Theorem 2 in Mlodozeniec et al.):
Under some assumptions on SGD (A1-A6 in the paper), as you continue training $t \to \infty$, you get a.s. convergence $(\theta_t, r_t) \to (\theta_\infty, r_\infty)$, where $\theta_\infty$ is a local minimum or a saddle point. Assume it is a local minimum; what is the limiting response $r_\infty$? It’s our beloved (pseudo-)inverse Hessian vector product (IHVP) from classical IFs, well… up to directions in weight space which lie in the kernel of the Hessian.
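A toy numerical illustration of this limit (my own, not from the paper): on a fixed quadratic loss with a degenerate Hessian, iterating the linearized response recursion $r_{t+1} = (I - \lambda H)\, r_t - \lambda g$ recovers the pseudo-inverse IHVP on the non-degenerate directions, while the kernel direction never converges:

```python
import numpy as np

# Quadratic loss L(w) = 0.5 * w^T H w with a degenerate Hessian (one zero eigenvalue).
H = np.diag([2.0, 0.5, 0.0])
g = np.array([1.0, 1.0, 1.0])  # gradient of the added sample's loss at the minimum
lr = 0.1

r = np.zeros(3)  # response r_t, starting at zero
for _ in range(2000):
    r = (np.eye(3) - lr * H) @ r - lr * g  # r_{t+1} = A r_t - lr * g

r_inf = -np.linalg.pinv(H) @ g  # classical IF answer: pseudo-inverse IHVP
print(r[:2], r_inf[:2])         # agree on directions outside the Hessian kernel
# The kernel component (3rd coordinate) drifts as -lr * g * t and never converges.
```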
So to summarize, the upshot is that influence functions can actually be valid beyond the original statistical setup, if (1) we model training dynamics linearly, (2) we believe assumptions A1-A6 plus that we end up in a local minimum eventually, and (3) we care about the limiting behaviour. These assumptions can and should be debated, but I find them more reasonable and interesting than the global minimum assumption.
And as a cherry on top, Theorem 3 shows that if you want to go from the Bayesian posterior $p(w \mid D)$ to the $\epsilon$-perturbed $p(w \mid D_\epsilon)$, you can again use IFs: sampling from the perturbed distribution is approximated by sampling from the original distribution and adding the IF IHVP. Amongst linear approximations, this one is (in a specific sense, in the low-temperature limit) optimal for the KL divergence.[3]
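Schematically (my paraphrase; the precise statement and scaling are in the paper):

$$w \sim p(w \mid D) \quad \Longrightarrow \quad w + \epsilon \, r_{\mathrm{IF}}(w) \;\approx\; \text{a sample from } p(w \mid D_\epsilon),$$

where $r_{\mathrm{IF}}$ denotes the IF response (the IHVP) for the perturbed data point.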
More generally, I think this paper makes an important point that goes beyond any of the technical details above: we want our counterfactual estimates to be more robust against randomness in training. But that’s for another time.
[1] E.g. I am not sure if I agree regarding the dataset vs. model size tradeoff, but maybe we have slightly different applications in mind :)

[2] A small upshot here is that we get a natural damping which mitigates the degeneracy of the Hessian.

[3] I would be curious to understand how this compares to the relationship you present in Appendix A.