RL with KL penalties is better seen as Bayesian inference

This blog post is largely based on an EMNLP paper with Ethan Perez and Chris Buckley. It also benefited from discussions with and comments from Hady Elsahar, Germán Kruszewski, Marc Dymetman and Jérémy Scheurer.

TLDR: KL-regularised RL, widely used as part of RL from human feedback (RLHF), is equivalent to variational inference: approximating a Bayesian posterior which specifies how to update a prior LM to conform with evidence provided by the reward function. The Bayesian perspective makes it clear that KL penalties aren’t a hack; they have a principled justification. It also nicely separates the modelling problem (defining a target distribution specifying the desired behaviour of an LM) and the inference problem (approximating that target distribution). Finally, it suggests that RL is not a good formal framework for thinking about LM alignment.

Introduction

Large language models (LMs) tend to generate outputs that reflect undesirable features of their training data such as offensiveness, social bias, harmfulness or dishonesty. Correcting these biases and constraining LMs to be honest, helpful and harmless is an essential part of the problem of aligning LMs with human preferences (henceforth “LM alignment”). One intuitive approach to LM alignment is reinforcement learning (RL): capturing human preferences as a reward function and training the LM to maximise the reward expected under the LM’s distribution. A practical recipe for implementing this idea is RL from human feedback (RLHF): first, a reward model is trained to predict which of two texts a human prefers, and then a pretrained LM is fine-tuned to maximise the reward given by the reward model while being penalised for Kullback-Leibler (KL) divergence from its initial distribution. However, despite the immense popularity of RLHF, the motivation for this KL penalty is not widely understood.

In this blog post, we discuss an underappreciated perspective on KL-regularised RL (the objective employed by RLHF for aligning LMs) which explains its empirical success. We start by describing a problem that arises from naively applying the standard RL objective: distribution collapse. The optimal policy under the RL objective would be a minimal-entropy LM generating a small set of sequences that obtain the highest reward. Then, we discuss how KL-regularised RL avoids distribution collapse thanks to its KL penalty. This constraint, we argue, transforms the problem from RL to Bayesian inference: updating a prior to conform with evidence provided by the reward.

Moreover, KL-regularised RL is equivalent to a well-studied approach to solving this inference problem approximately: variational inference. This Bayesian perspective explains how KL-regularised RL avoids the distribution collapse problem and offers a first-principles derivation for its objective. It also moves KL-regularised RL closer to other divergence-minimisation-based approaches to fine-tuning LMs such as GDC, which are not equivalent to RL and naturally avoid the distribution collapse problem. In contrast, RL avoids distribution collapse only with a particular choice of reward function that makes it equivalent to Bayesian inference. This suggests that RL might not be an adequate formal framework for problems such as LM alignment.

Aligning language models via standard RL

Let $\mathcal{X}$ be the set of sequences of tokens from some vocabulary. An LM $\pi$ can be seen as a probability distribution over $\mathcal{X}$. While most modern LMs are autoregressive, for simplicity we will only talk about full sequences, e.g. $\pi(x)$ denotes the probability of a sequence $x \in \mathcal{X}$. Similarly, a reward function $r$ assigns a scalar reward $r(x)$ to each sequence $x \in \mathcal{X}$. In the context of LM alignment, $r$ represents the human preferences we want $\pi$ to be aligned with, e.g. a non-offensiveness reward would assign low values to sequences that are offensive.

If $\pi_\theta$ is our parametric LM (with parameters $\theta$), the RL objective for aligning it with our reward function $r$ is just the reward expected under the LM’s distribution:

$$J_\text{RL}(\theta) = \mathbb{E}_{x \sim \pi_\theta}[r(x)]$$

Intuitively, maximising $J_\text{RL}(\theta)$ means sampling a number of sequences from the LM, rewarding the LM for good sequences and penalising it for bad ones (e.g. offensive sentences). This approach to LM alignment is appealing in several ways, especially when compared with the standard self-supervised language modelling objective of predicting the next token in a static dataset. Because the samples come from the LM itself (as opposed to a static dataset), the sampling distribution naturally follows what the LM has already learned and the reward is only evaluated on the LM’s current best guesses about the correct behaviour. For instance, assume the reward is non-offensiveness and this reward involves, but is not limited to, avoiding curse words. Then, the LM could quickly learn to avoid curses and then focus on avoiding more elaborate forms of toxicity, wasting no time on sequences containing curse words.

The problem with the RL objective is that it treats the LM as a policy, not as a generative model. While a generative model is supposed to capture a diverse distribution of samples, a policy is supposed to choose the optimal action. In the LM context, where we don’t have a notion of state, the RL objective reduces to searching for $x^*$, the sequence with the highest reward. If there is one, the optimal policy $\pi^*$ is a degenerate, deterministic generative model that puts its entire probability mass on that single sequence:

$$\pi^* = \arg\max_{\pi_\theta} J_\text{RL}(\theta) = \delta_{x^*}$$

where $\delta_{x^*}$ is a Dirac delta distribution centred on $x^*$. If there are multiple optimal sequences $x^*$, probability mass would be put only on them.
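The collapse can be reproduced on a toy problem. The following is a minimal sketch, assuming a hypothetical "LM" that is just a softmax over five candidate sequences with made-up rewards (none of these numbers come from the experiments discussed in this post). It performs exact gradient ascent on the expected reward and tracks the entropy of the policy.

```python
import numpy as np

# Hypothetical toy setup: five "sequences" with made-up rewards.
rewards = np.array([0.1, 0.5, 0.2, 1.0, 0.3])

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def entropy(p):
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())

logits = np.zeros_like(rewards)   # start from a uniform "LM"
p_init = softmax(logits)

for _ in range(5000):
    p = softmax(logits)
    expected_reward = p @ rewards
    # Exact gradient of E_{x ~ p}[r(x)] with respect to the softmax logits.
    logits += 1.0 * p * (rewards - expected_reward)

p_final = softmax(logits)
print("initial entropy:", entropy(p_init))
print("final entropy:  ", entropy(p_final))
print("final policy:   ", np.round(p_final, 3))
```

In this sketch nearly all probability mass ends up on the single highest-reward sequence, even though the other sequences also receive positive reward.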

This failure mode is not purely theoretical. Empirically, distribution collapse induced by maximising reward manifests as decreased fluency and diversity of samples from the LM, which can be measured in terms of perplexity, entropy and the frequency of repetitions. Degeneration of this kind was observed in multiple language generation tasks ranging from translation, summarisation, story generation, video captioning, dialogue, to code generation and LM debiasing.

Figure 1: Samples from an LM fine-tuned using $J_\text{RL}(\theta)$ with reward $r(x) = 1$ if $x$ contains the word “Paris”, $r(x) = 0$ otherwise. Even though there are infinitely many sentences containing “Paris” and the LM is not rewarded for multiple mentions of “Paris”, it still converges to a very low-entropy policy mentioning Paris as often as possible, just in case. Figure adapted from Khalifa et al., 2021.

While the degeneration problem is exacerbated by RL failure modes such as insufficient exploration or reward hacking, it is distinct from the exploration-exploitation trade-off or reward misspecification. Even with perfect exploration (if we sampled sequences uniformly from $\mathcal{X}$ as opposed to sampling from $\pi_\theta$), the optimal policy will still put all probability mass on $x^*$. Similarly, even if $r$ is a smooth, real-valued function that perfectly captures human preferences across the whole space of possible sequences $\mathcal{X}$, and even if $x^*$ is truly the best thing, we still wouldn’t want the LM to generate only $x^*$. Essentially, the distribution collapse problem arises from the fact that the RL objective for LM alignment is flawed: it doesn’t care about preserving distributional properties of an LM and will always penalise the LM for putting any probability mass on non-optimal sequences until the LM collapses into a degenerate distribution.

Fine-tuning language models via KL-regularised RL

Couldn’t we somehow include preserving distributional properties of an LM as part of the reward function? The notion of preserving distributional properties of an LM $\pi_\theta$ can be formalised as penalising for Kullback-Leibler (KL) divergence between $\pi_\theta$ and some other, pretrained LM $\pi_0$ (e.g. publicly available GPT2). Typically, $\pi_\theta$ is initialised to $\pi_0$ and then fine-tuned to maximise the following objective:

$$J_\text{KL-RL}(\theta) = \mathbb{E}_{x \sim \pi_\theta}[r(x)] - \beta\,\text{KL}(\pi_\theta, \pi_0)$$

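where the KL is defined as

$$\text{KL}(\pi_\theta, \pi_0) = \mathbb{E}_{x \sim \pi_\theta} \log \frac{\pi_\theta(x)}{\pi_0(x)}$$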

The first term in $J_\text{KL-RL}$ is equivalent to $J_\text{RL}$ while the second additionally constrains $\pi_\theta$ to stay close (in terms of KL) to $\pi_0$. Almost always some reward needs to be sacrificed for that; the coefficient $\beta$ determines the trade-off of how much reward is needed to justify departing from $\pi_0$ by a certain distance. This objective is commonly used as part of a popular recipe for fine-tuning LMs termed “RL from Human Feedback” (RLHF) and works surprisingly well in practice.

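$J_\text{KL-RL}(\theta)$ can easily be reformulated as just expected reward, the standard RL objective. We only have to define a new reward function $r'_\theta(x)$ which incorporates the original reward $r$ and the KL penalty, using the definition of KL divergence:

$$J_\text{KL-RL}(\theta) = \mathbb{E}_{x \sim \pi_\theta}[r'_\theta(x)]$$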

where

$$r'_\theta(x) = r(x) + \beta\left(\log \pi_0(x) - \log \pi_\theta(x)\right)$$

This new reward function additionally rewards sequences likely under $\pi_0$ (therefore fluent) and unlikely under $\pi_\theta$ itself (an entropy bonus). But even in this formulation, $J_\text{KL-RL}$ is not a standard RL objective: now the reward depends on the policy parameters $\theta$, which makes it non-stationary and coupled with $\pi_\theta$. But is framing the maximisation of $J_\text{KL-RL}$ as RL really necessary? In the next section, we will develop an alternative view of this objective (as an approximate solution to a Bayesian inference problem) and argue that it is more appealing than the RL framing.
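The reformulation can be checked numerically. Below is a small sketch, assuming a hypothetical finite set of six sequences with randomly drawn distributions and rewards (all values are illustrative only). It evaluates the KL-regularised objective once as "reward minus KL penalty" and once as the expectation of the modified, policy-dependent reward.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6                                  # hypothetical number of sequences
pi_0 = rng.dirichlet(np.ones(n))       # toy pretrained LM
pi_theta = rng.dirichlet(np.ones(n))   # toy fine-tuned LM
r = rng.normal(size=n)                 # made-up rewards
beta = 0.5                             # made-up KL coefficient

def kl(p, q):
    """KL(p, q) = E_{x ~ p}[log p(x) - log q(x)]."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

# Objective written as expected reward minus the KL penalty.
j_penalty = float(pi_theta @ r) - beta * kl(pi_theta, pi_0)

# Objective written as the expectation of the modified reward.
r_prime = r + beta * (np.log(pi_0) - np.log(pi_theta))
j_modified = float(pi_theta @ r_prime)

print(j_penalty, j_modified)
```

Note that `r_prime` has to be recomputed whenever `pi_theta` changes, which is the non-stationarity mentioned above.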

KL-regularised RL as variational inference

Aligning a pretrained LM $\pi_0$ with preferences encoded by a reward function $r$ is essentially a Bayesian inference problem. Intuitively, Bayesian inference is the problem of updating a distribution to conform with new evidence. Given the prior probability $p(h)$ of a hypothesis $h$ and the likelihood $p(e|h)$ of evidence $e$ assuming $h$, the posterior probability of $h$ is given by Bayes’ theorem: $p(h|e) \propto p(e|h)p(h)$. In our setting, we’re updating $\pi_\theta$, which is initially equal to a prior $\pi_0$, to conform with evidence provided by the assumption that $\pi_\theta$ is optimal in terms of $r$. A reward function can be represented as a distribution over $\mathcal{X}$ that makes high-reward sequences more likely than low-reward sequences. A simple way of doing that is exponentiating the reward $r$ and then rescaling it to be a normalised probability distribution. Then, the posterior is given by:

$$\pi^*_\text{KL-RL}(x) = \frac{1}{Z}\pi_0(x)\exp(r(x)/\beta)$$

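where $\pi_0$ is the prior, $\exp(r(x)/\beta)$ is the evidence provided by the reward function (scaled by temperature $\beta$) and $Z$ is a constant ensuring that $\pi^*_\text{KL-RL}$ is a normalised probability distribution. $\pi^*_\text{KL-RL}$ represents a version of $\pi_0$ updated to account for the reward $r$. It also happens to coincide with the optimal policy for $J_\text{KL-RL}$:

$$\pi^*_\text{KL-RL} = \arg\max_{\pi_\theta} J_\text{KL-RL}(\theta)$$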

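Moreover, the KL-regularised RL objective can be cast as minimising the KL divergence between the LM $\pi_\theta$ and this target distribution $\pi^*_\text{KL-RL}$:

$$J_\text{KL-RL}(\theta) = -\beta\,\text{KL}(\pi_\theta, \pi^*_\text{KL-RL}) + \beta \log Z$$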

That’s a different KL than the KL penalty term $\text{KL}(\pi_\theta, \pi_0)$ we’ve seen before, and the term $\beta \log Z$ does not depend on $\theta$. Minimising this new KL is equivalent to variational inference, a well-known approach to approximating Bayesian inference. More formally, $J_\text{KL-RL}(\theta)$ is the evidence lower bound (ELBO) on the log likelihood of $\pi_\theta$ being optimal under $r$, assuming a prior $\pi_0$. Maximising this bound makes $\pi_\theta$ approximate the true posterior $\pi^*_\text{KL-RL}$. A derivation of these equalities can be found in the appendix below.
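As a sanity check of the claim that the posterior is also the optimal policy, here is a rough sketch on a hypothetical six-sequence space with randomly drawn prior and rewards (all names and values are assumptions made for illustration). It builds the posterior in closed form and compares its objective value against a thousand random candidate distributions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, beta = 6, 0.5                       # hypothetical problem size and KL coefficient
pi_0 = rng.dirichlet(np.ones(n))       # toy prior LM
r = rng.normal(size=n)                 # made-up rewards

def kl(p, q):
    return float(np.sum(p * (np.log(p) - np.log(q))))

def j_kl_rl(pi):
    """Expected reward under pi minus beta times KL(pi, pi_0)."""
    return float(pi @ r) - beta * kl(pi, pi_0)

# Posterior: prior times exponentiated reward, renormalised by Z.
unnormalised = pi_0 * np.exp(r / beta)
Z = float(unnormalised.sum())
posterior = unnormalised / Z

# Compare against random candidate policies on the same space.
candidates = rng.dirichlet(np.ones(n), size=1000)
gaps = [j_kl_rl(posterior) - j_kl_rl(pi) for pi in candidates]

print("objective at posterior:", j_kl_rl(posterior))
print("beta * log Z:          ", beta * np.log(Z))
print("smallest gap to a random candidate:", min(gaps))
```

Every gap is positive in this sketch, and the objective value at the posterior matches `beta * log Z`, which is what the divergence form of the objective predicts when the KL term vanishes.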

Why is this picture insightful? For one, it explains where the KL penalty term $\beta\,\text{KL}(\pi_\theta, \pi_0)$ in KL-regularised RL’s original objective comes from. It is necessary to transform the problem from RL to minimising a divergence from a target distribution $\pi^*_\text{KL-RL}$. This in turn makes the distributional character of an LM a first-class citizen, which explains why KL-regularised RL is able to maintain the fluency and diversity of the original LM $\pi_0$.
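The role of the coefficient $\beta$ as a temperature can be illustrated by sweeping it in the closed-form posterior (prior times exponentiated reward, renormalised). The sketch below assumes a made-up prior and made-up rewards over five hypothetical sequences; it is only meant to show the qualitative trend.

```python
import numpy as np

pi_0 = np.array([0.4, 0.3, 0.15, 0.1, 0.05])   # toy prior LM (illustrative values)
r = np.array([0.0, 0.2, 0.5, 1.0, 0.8])        # made-up rewards

def posterior(beta):
    # Subtracting r.max() only rescales Z and keeps the exponent stable.
    w = pi_0 * np.exp((r - r.max()) / beta)
    return w / w.sum()

def kl(p, q):
    mask = p > 0
    return float(np.sum(p[mask] * (np.log(p[mask]) - np.log(q[mask]))))

betas = [0.01, 0.1, 1.0, 10.0, 100.0]
expected_rewards = [float(posterior(b) @ r) for b in betas]
kls = [kl(posterior(b), pi_0) for b in betas]

for b, er, d in zip(betas, expected_rewards, kls):
    print(f"beta={b:>6}: expected reward={er:.3f}, KL from prior={d:.3f}")
```

With a small $\beta$ the posterior is close to the collapsed, reward-maximising distribution; with a large $\beta$ it stays close to the prior.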

Separation of modelling and inference

In the last section, we have argued that KL-regularised RL is secretly variational inference and that this vantage point elegantly explains why it works. Here, we explore a different advantage of the Bayesian perspective. Essentially, what it says is that aligning an LM with human preferences is a two-step process:

  1. First, you define a distribution specifying the desired behaviour of your LM. A principled way of doing that is using Bayes’ rule to define a posterior like $\pi^*_\text{KL-RL}$.

  2. Second, you figure out how to sample from your posterior.

These two steps roughly correspond to what’s known as modelling and inference in probabilistic programming. Modelling is encoding your knowledge in probabilistic terms (usually by defining a probabilistic graphical model) while inference corresponds to using this model to answer queries. It’s hard to overstate how useful (theoretically and practically) separating these two concerns could be. Let’s discuss these two steps separately below.

Modelling. For LMs, the modelling step is relatively easy: our LM is natively a probability distribution and autoregressive models are great for both sampling and evaluating likelihoods. Most modelling decisions are usually around interpreting human preferences in probabilistic terms. Turning a reward function $r$ into a distribution by exponentiating it ($\frac{1}{Z}\exp(r(x))$) is one idea, but there are other ones. Here are a few:

  1. A standard reward model assigns each sample $x$ a single, scalar score $r(x)$. Maybe we’d like instead to have a model that captures a distribution of human preferences associated with a single sample $x$ and use that as part of our posterior?

  2. A simpler variant of the previous idea is to use one of multiple ways of eliciting uncertainty estimates from a standard (scalar) reward model. What’s nice about uncertainties is that they tell the LM that some rewards $r(x)$ are high-precision (so the LM should update a lot based on them) while others are uncertain (perhaps $x$ is out of distribution for the reward model) and the LM should tread lightly.

  3. Finally, maybe our preferences are binary, e.g. the LM can never, ever say anything very offensive but is free to behave normally otherwise. Then, we could define $\pi^*(x) = \frac{1}{Z}\pi_0(x)b(x)$ where $b(x) = 0$ if $x$ contains a curse and $b(x) = 1$ otherwise. Then, sequences $x$ containing curses have probability zero according to $\pi^*$ (hence $\pi^*$ is non-offensive) but all other strings keep their original probability up to the normalising constant $Z$ (hence no degeneration).

All the posteriors mentioned above are non-parametric: they exist as mathematical objects, but we don’t know the set of Transformer weights $\theta$ that corresponds to them. Moreover, in general these posteriors lie outside the class of probability distributions representable by a Transformer LM. Figuring out an actual piece of code generating samples matching this posterior distribution constitutes the inference problem.

Inference. Broadly, there are two classes of algorithms for inference on probabilistic graphical models: variational inference and sampling-based approaches. Variational inference tries to find the set of Transformer weights that give rise to a distribution closest (in terms of KL) to the true posterior. Sampling-based techniques, such as MCMC, don’t represent the true posterior explicitly, but compute samples from a distribution resembling the true posterior.

In the previous section, we’ve shown that KL-regularised RL corresponds to variational inference. But sampling-based inference algorithms also have analogues for LMs as decoding-time methods. Decoding-time methods boil down to simulating a posterior, aligned LM by modifying the generation procedure applied on top of the original LM. The simplest example is also the most popular alignment method used in multiple production systems: filtering (also known as rejection sampling). You can simulate a non-offensive LM by using the following procedure: if the LM generates an offensive sample, you discard it and try again. More elaborate decoding-time methods include weighted decoding and PPLM.
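Filtering can be sketched on a toy example. The code below assumes a hypothetical five-item "vocabulary of sequences" with a made-up prior and two placeholder items standing in for offensive outputs. It draws samples from the prior, discards the offensive ones, and compares the surviving samples with the binary-preference posterior (prior times a 0/1 indicator, renormalised).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sequences; CURSE_A and CURSE_B are placeholders for offensive outputs.
sequences = ["hello", "thanks", "CURSE_A", "goodbye", "CURSE_B"]
pi_0 = np.array([0.35, 0.25, 0.15, 0.15, 0.10])   # made-up prior LM
offensive = {"CURSE_A", "CURSE_B"}
b = np.array([0.0 if x in offensive else 1.0 for x in sequences])

# Target posterior: prior restricted to non-offensive sequences, renormalised.
Z = float((pi_0 * b).sum())
target = pi_0 * b / Z

# Rejection sampling: sample from the prior, keep only non-offensive samples.
samples = rng.choice(len(sequences), size=200_000, p=pi_0)
accepted = samples[b[samples] == 1.0]
empirical = np.bincount(accepted, minlength=len(sequences)) / len(accepted)
acceptance_rate = len(accepted) / len(samples)

print("target:    ", np.round(target, 3))
print("empirical: ", np.round(empirical, 3))
print("acceptance rate:", round(acceptance_rate, 3), "vs Z =", Z)
```

The acceptance rate approximates $Z$, which hints at the generation-time cost of this approach: the rarer acceptable samples are under the prior, the more samples have to be discarded.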

To summarise, we’ve seen that the Bayesian view provides a nice unifying perspective on fine-tuning and decoding-time approaches to LM alignment. They mirror variational inference and sampling-based inference algorithms for probabilistic graphical models, respectively, with their respective trade-offs (training efficiency vs generation efficiency). But a more fundamental advantage, to our mind, is what we’ve started with: the separation of concerns between defining a desired behaviour of an LM and approximating it. The choice of posterior is independent of how you’d like to approximate it. You can therefore separate two failure modes: misspecifying the model (i.e. not capturing the preferences) and failing to approximate the model well enough. In principle, you could try to approximate KL-regularised RL’s posterior using a fancy decoding algorithm and validate if this distribution indeed captures your preferences, without doing costly training. If there’s an efficient way of doing that, then maybe training an actual LM (one allowing for fast generation) could be delayed until prototyping the posterior is done.

Is RL a good framework for LM alignment?

Let me end with a more philosophical implication of the Bayesian perspective on KL-regularised RL. If it’s the Bayesian perspective that justifies theoretically using KL-regularised RL, is the original perspective — the RL perspective — still useful?

There is a family of other divergence minimisation approaches to fine-tuning LMs which are not equivalent to RL. Take Generative Distributional Control (GDC), an approach to fine-tuning LMs that obtains results comparable with KL-regularised RL but minimises a slightly different divergence:

$$J_\text{GDC}(\theta) = -\text{KL}(\pi^*_\text{GDC}, \pi_\theta)$$

where $\pi^*_\text{GDC}$ is an exponential family distribution similar to $\pi^*_\text{KL-RL}$. The difference between $J_\text{GDC}$ and $J_\text{KL-RL}$ is in the order of arguments (forward vs reverse KL). However, $J_\text{GDC}$ is no longer equivalent to RL because the expectation in the forward KL divergence is with respect to $\pi^*_\text{GDC}$, not $\pi_\theta$. Similarly, the standard supervised training objective can be seen as minimising $\text{KL}(\pi^*_\text{MLE}, \pi_\theta)$, a divergence from the empirical distribution $\pi^*_\text{MLE}$ provided by the training set.
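The asymmetry between the two argument orders is easy to see numerically. The following sketch uses made-up three-outcome distributions (not GDC's actual target) and a small helper `kl` that takes its expectation under its first argument, so swapping the arguments changes which distribution the samples would have to come from.

```python
import numpy as np

def kl(p, q):
    """KL(p, q) = E_{x ~ p}[log p(x) - log q(x)]; infinite if q misses part of p's support."""
    mask = p > 0
    if np.any(q[mask] == 0):
        return float("inf")
    return float(np.sum(p[mask] * (np.log(p[mask]) - np.log(q[mask]))))

target = np.array([0.5, 0.5, 0.0])   # toy target distribution
narrow = np.array([1.0, 0.0, 0.0])   # toy model covering only part of the target
broad = np.array([0.4, 0.4, 0.2])    # toy model spreading mass outside the target

# Reverse KL: expectation under the model (the order used by KL-regularised RL).
reverse_narrow = kl(narrow, target)
reverse_broad = kl(broad, target)

# Forward KL: expectation under the target (the order used by GDC and supervised training).
forward_narrow = kl(target, narrow)
forward_broad = kl(target, broad)

print("reverse KL:", reverse_narrow, reverse_broad)
print("forward KL:", forward_narrow, forward_broad)
```

In this toy case the two orders disagree about which model is acceptable: the reverse KL is finite only for the narrow model, the forward KL only for the broad one.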

One can therefore mount a double dissociation argument in favour of the divergence minimisation perspective on KL-regularised RL: RL without distribution matching fails, divergence minimisation without RL works. Therefore, it’s the divergence minimisation aspect of KL-regularised RL that accounts for its success. In consequence, calling it RL is just a redescription of it that happens to be correct under a particular choice of reward function $r'_\theta$, but does not provide motivation for this choice of $r'_\theta$ and does not hold for alternative divergence minimisation approaches to fine-tuning LMs such as GDC.

The divergence minimisation perspective on KL-regularised RL we presented stems from a general framework known as control as inference. Control as inference provides a formalisation of intelligent decision making as inference on a probabilistic graphical model representing the agent, its preferences and environmental dynamics. While control as inference is typically considered with graphical models parameterised to make it equivalent to RL, it does not have to be. Moreover, there are frameworks such as active inference or APD that further generalise control as inference to a general principle of minimising the KL divergence from a probability distribution representing desired behaviour of the agent. In contrast with RL, they conceptualise the agent as a generative model, not as a decision rule represented as a probability distribution out of convenience. Therefore, they naturally avoid the distribution collapse problem and preserve the distributional properties of the agent. What if RL simply isn’t an adequate formal framework for problems such as aligning LMs?

Mathematical appendix

This section is just a step-by-step derivation of the equivalence between KL-regularised RL optimal policy and Bayesian posterior and the equivalence between KL-regularised RL’s objective and variational inference’s ELBO.

Let’s assume we have a prior distribution $p(x)$ over sequences of tokens and a reward function $r$ which is (for technical reasons) never positive (it ranges from $-\infty$ to 0). We can also represent $r$ as a binary random variable $\mathcal{O}$ (the optimality variable), where $\mathcal{O} = 1$ if a certain LM $\pi$ is optimal. We can define $\mathcal{O}$ in terms of $r$ as

$$p(\mathcal{O} = 1 | x) = \exp(r(x))$$

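which is a valid probability (between 0 and 1) because $r(x)$ is never positive. For instance, if $r(x)$ is a log probability that a sequence $x$ is non-offensive, $p(\mathcal{O} = 1 | x)$ is a probability that $x$ is non-offensive and the marginal $p(\mathcal{O} = 1)$ is the average non-offensiveness score of $p(x)$ (or a probability that a random sample from $p(x)$ is non-offensive). The problem of aligning LMs can be seen as inferring $p(x | \mathcal{O} = 1)$, a distribution over sequences of tokens conditioned on being non-offensive. This can be computed by applying Bayes’ rule as

$$p(x | \mathcal{O} = 1) = \frac{p(\mathcal{O} = 1 | x)\,p(x)}{p(\mathcal{O} = 1)} = \frac{1}{Z}\pi_0(x)\exp(r(x))$$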

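where we chose the prior $p(x) = \pi_0(x)$, redefined the marginal $p(\mathcal{O} = 1)$ as the normalising constant $Z$, used the definition of $p(\mathcal{O} = 1 | x)$ and chose $\beta = 1$. $p(x | \mathcal{O} = 1)$ here is equivalent to $\pi^*_\text{KL-RL}$, the optimal policy under $J_\text{KL-RL}$ (up to the choice of $\beta$, which can be absorbed into $r$ anyway).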

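$p(x | \mathcal{O} = 1)$ is a non-parametric distribution: it doesn’t have to lie in the family of distributions representable by a parametric model. In general, we’d like to find a parametric model $\pi_\theta$ closest to $\pi^*_\text{KL-RL}$. This can be formalised as finding $\pi_\theta$ that minimises $\text{KL}(\pi_\theta, \pi^*_\text{KL-RL})$. Here, however, we will derive this objective from a yet more general perspective: inferring a random latent variable $x$ that best explains the assumption that a certain LM $\pi$ is optimal given a prior $\pi_0(x)$. This can be seen as maximising the log-likelihood of $\mathcal{O} = 1$ via variational inference:

$$\begin{aligned}
\log p(\mathcal{O} = 1) &= \log \sum_x p(\mathcal{O} = 1, x) \\
&= \log \sum_x p(\mathcal{O} = 1 | x)\,\pi_0(x) \\
&= \log \sum_x \pi_\theta(x)\,p(\mathcal{O} = 1 | x)\,\frac{\pi_0(x)}{\pi_\theta(x)} \\
&\geq \sum_x \pi_\theta(x) \log \left[ p(\mathcal{O} = 1 | x)\,\frac{\pi_0(x)}{\pi_\theta(x)} \right] \\
&= \mathbb{E}_{x \sim \pi_\theta} \log \left[ \exp(r(x))\,\frac{\pi_0(x)}{\pi_\theta(x)} \right]
\end{aligned}$$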

In this derivation, we first introduce a latent variable $x$ using the sum rule of probability, factorise the joint distribution, introduce a variational distribution $\pi_\theta$ over that latent variable, use Jensen’s inequality to obtain a bound (the ELBO) and, finally, use the definition of $p(\mathcal{O} = 1 | x)$.

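This new bound can be alternatively expressed in two different ways. One is just the KL-regularised RL objective $J_\text{KL-RL}(\theta)$ with $\beta = 1$:

$$\mathbb{E}_{x \sim \pi_\theta} \log \left[ \exp(r(x))\,\frac{\pi_0(x)}{\pi_\theta(x)} \right] = \mathbb{E}_{x \sim \pi_\theta}[r(x)] - \text{KL}(\pi_\theta, \pi_0)$$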

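The second one is equal, up to the additive constant $\log Z$, to the negative $\text{KL}(\pi_\theta, \pi^*_\text{KL-RL})$:

$$\mathbb{E}_{x \sim \pi_\theta} \log \left[ \exp(r(x))\,\frac{\pi_0(x)}{\pi_\theta(x)} \right] = \mathbb{E}_{x \sim \pi_\theta} \log \frac{Z\,\pi^*_\text{KL-RL}(x)}{\pi_\theta(x)} = -\text{KL}(\pi_\theta, \pi^*_\text{KL-RL}) + \log Z$$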

where $\pi^*_\text{KL-RL}$ is the target distribution (or optimal policy for $J_\text{KL-RL}$). Their equivalence proves that KL-regularised reward maximisation is equivalent to minimising the divergence from $\pi^*_\text{KL-RL}$.
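The bound and its two readings can be verified numerically. The sketch below assumes a hypothetical six-sequence space with a random prior and random non-positive rewards (illustrative values only), takes $\beta = 1$, and checks that the ELBO never exceeds the log evidence, that the gap is exactly the KL divergence from the posterior, and that the bound is tight at the posterior.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6                                  # hypothetical number of sequences
pi_0 = rng.dirichlet(np.ones(n))       # toy prior LM
r = -rng.exponential(size=n)           # non-positive rewards, so exp(r) lies in (0, 1]
pi_theta = rng.dirichlet(np.ones(n))   # toy variational distribution

def kl(p, q):
    return float(np.sum(p * (np.log(p) - np.log(q))))

def elbo(pi):
    """E_{x ~ pi} log[exp(r(x)) * pi_0(x) / pi(x)]."""
    return float(np.sum(pi * (r + np.log(pi_0) - np.log(pi))))

Z = float(np.sum(pi_0 * np.exp(r)))    # marginal probability of optimality
log_evidence = float(np.log(Z))
posterior = pi_0 * np.exp(r) / Z

print("log evidence:       ", log_evidence)
print("ELBO at pi_theta:   ", elbo(pi_theta))
print("gap:                ", log_evidence - elbo(pi_theta))
print("KL(pi_theta, post.):", kl(pi_theta, posterior))
print("ELBO at posterior:  ", elbo(posterior))
```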