Distillation: RL with KL penalties is better viewed as Bayesian inference

The paper RL with KL penalties is better viewed as Bayesian inference argues that Bayesian inference is a more insightful theoretical grounding than reinforcement learning for aligning language models with human feedback.

There is also a good LW post about the contents of this paper, including relevant proofs.

This post aims to summarize the key concepts further.

The RLHF paradigm

In the RLHF paradigm, a reward model is trained to predict which of two texts a human prefers. Then, a pre-trained language model (LM) is fine-tuned to maximize the reward given by the reward model while also being penalized for Kullback-Leibler (KL) divergence from its initial distribution.
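The post does not say how the reward model is trained. A common choice, assumed here, is the Bradley-Terry formulation. The probability that text A is preferred over text B is modeled as sigmoid(r(A) - r(B)), and the reward model is trained to minimize the negative log-likelihood of the human's recorded choice. The rewards below are hypothetical scalars, not outputs of any real model:

```python
import math

def preference_probability(reward_a: float, reward_b: float) -> float:
    """Bradley-Terry probability that text A is preferred over text B."""
    return 1.0 / (1.0 + math.exp(-(reward_a - reward_b)))

def pairwise_loss(reward_chosen: float, reward_rejected: float) -> float:
    """Negative log-likelihood of the human's choice under Bradley-Terry."""
    return -math.log(preference_probability(reward_chosen, reward_rejected))

# Hypothetical rewards a reward model might assign to a chosen and a rejected text.
print(preference_probability(2.0, 0.0))  # ~0.881
print(pairwise_loss(2.0, 0.0))           # ~0.127
```

Only the reward difference enters the loss, so the reward model's scale and offset are fixed only up to what the preference data constrains.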

KL divergence measures how one probability distribution diverges from a second, reference probability distribution. A separate post provides a number of useful intuitions about KL divergence.
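For discrete distributions, $D_{\mathrm{KL}}(P \,\|\, Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}$. Below is a minimal sketch with made-up distributions. It shows that the divergence is asymmetric and is zero when the two distributions match:

```python
import math

def kl_divergence(p: list[float], q: list[float]) -> float:
    """D_KL(p || q) = sum_x p(x) * log(p(x) / q(x)) for discrete distributions.

    Terms with p(x) == 0 contribute 0. If q(x) == 0 where p(x) > 0,
    the divergence is infinite.
    """
    total = 0.0
    for px, qx in zip(p, q):
        if px == 0.0:
            continue
        if qx == 0.0:
            return math.inf
        total += px * math.log(px / qx)
    return total

p = [0.5, 0.5]
q = [0.9, 0.1]
print(kl_divergence(p, q))  # ~0.511 (equals log(5/3))
print(kl_divergence(q, p))  # ~0.368, so KL is not symmetric
print(kl_divergence(p, p))  # 0.0
```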

Shortcomings of the RL objective and the importance of distribution

Empirically, we know the success of RLHF relies on using the KL divergence term. Without it, we get the problem of distributional collapse; we end up with a low entropy LM that generates the small set of sequences that obtain the highest reward. Specifically, RLHF has been shown to increase perplexity, decrease entropy, and increase the frequency of repetitions.
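Below is a toy illustration of this collapse, using a made-up four-sequence "LM" and made-up rewards. The pure RL objective $\mathbb{E}_\pi[r]$ is maximized by putting all probability on the highest-reward sequence, which gives zero entropy. The optimum of the KL-regularized objective $\mathbb{E}_\pi[r] - \beta\,\mathrm{KL}(\pi \,\|\, \pi_0)$ is a standard closed form: the prior reweighted by $\exp(r/\beta)$ and renormalized. That optimum keeps the diversity of the prior. The coefficient $\beta$ here is a hypothetical setting:

```python
import math

def entropy(p):
    """Shannon entropy in nats, skipping zero-probability outcomes."""
    return sum(px * math.log(1.0 / px) for px in p if px > 0)

def reward_maximizer(rewards):
    """Optimum of the pure RL objective E_pi[r]: all mass on the best sequence."""
    best = max(range(len(rewards)), key=lambda i: rewards[i])
    return [1.0 if i == best else 0.0 for i in range(len(rewards))]

def kl_regularized_optimum(prior, rewards, beta):
    """Optimum of E_pi[r] - beta * KL(pi || prior): prior * exp(r / beta), renormalized."""
    weights = [p0 * math.exp(r / beta) for p0, r in zip(prior, rewards)]
    z = sum(weights)
    return [w / z for w in weights]

prior = [0.4, 0.3, 0.2, 0.1]    # hypothetical pre-trained LM over four sequences
rewards = [1.0, 1.2, 0.5, 1.3]  # hypothetical reward-model scores

print(entropy(prior))                                             # ~1.28
print(entropy(reward_maximizer(rewards)))                         # 0.0
print(entropy(kl_regularized_optimum(prior, rewards, beta=1.0)))  # stays well above 0
```

Shrinking `beta` moves the regularized optimum toward the collapsed one-hot solution. Growing `beta` keeps it closer to the prior.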

The problem with a pure RL objective is that it treats the LM as a policy, not as a generative model. A generative model is supposed to capture a diverse distribution of samples, whereas a policy is supposed to choose the optimal action. But we don't want the LM to only ever generate a single, optimal output.

An intuitive explanation for why is that a model can only capture limited information, because it has finite parameters and compute. Modeling the single optimal output is therefore too hard and computationally intractable, so the model needs to retain some entropy (uncertainty). This is why you should aim to capture an accurate probability distribution over answers.

Concretely, when the LM produces a full distribution, we can measure its uncertainty. These uncertainty estimates are well-calibrated for larger models and allow safer deployment in high-stakes scenarios.

Furthermore, maximum a posteriori (MAP) estimates of the output, which pick the single most probable output under the model, are inevitably inexact. They can be substantially improved by decoding procedures that consider the entire distribution.
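Minimum Bayes risk (MBR) decoding is one well-known example of a procedure that considers the entire distribution. The point above does not depend on this particular choice. The toy sketch below uses a hypothetical output distribution and a word-overlap utility chosen purely for illustration. MAP returns the single most likely output. MBR returns the output that agrees best with the rest of the probability mass:

```python
def map_decode(dist):
    """MAP decoding: return the single most probable output."""
    return max(dist, key=dist.get)

def mbr_decode(dist, utility):
    """Minimum Bayes risk decoding: return the candidate whose expected
    utility against the full output distribution is highest."""
    def expected_utility(candidate):
        return sum(p * utility(candidate, y) for y, p in dist.items())
    return max(dist, key=expected_utility)

def jaccard(a, b):
    """Hypothetical utility: word-level Jaccard similarity between two outputs."""
    sa, sb = set(a.split()), set(b.split())
    return len(sa & sb) / len(sa | sb)

# Hypothetical output distribution: two near-paraphrases split most of the mass.
dist = {"the cat sat": 0.4, "a dog ran": 0.35, "a dog ran off": 0.25}

print(map_decode(dist))           # the cat sat
print(mbr_decode(dist, jaccard))  # a dog ran
```

Here the mode holds only 40% of the mass. The two "dog" outputs together hold 60%, and MBR picks up on that.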

Bayesian inference as a superior frame

RL with KL penalties is better viewed as Bayesian inference proves that the KL-regularized reward maximization objective is the evidence lower bound (ELBO) on the log-likelihood of the model being optimal under the reward function, given a prior represented by the initial, pre-trained LM.
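Here is a sketch of the ELBO argument under the usual assumptions of this construction. Introduce a binary "optimality" variable $\mathcal{O}$ with likelihood $p(\mathcal{O}=1 \mid x) = e^{r(x)/\beta}$. Rewards are assumed shifted so that $r \le 0$, which keeps this a valid probability, and $\beta$ is the KL coefficient. The prior over texts $x$ is the pre-trained LM $\pi_0$, and $\pi_\theta$ is the fine-tuned LM. Jensen's inequality then gives:

```latex
\begin{aligned}
\log p(\mathcal{O}=1)
  &= \log \sum_x \pi_0(x)\, e^{r(x)/\beta}
   = \log \mathbb{E}_{x \sim \pi_\theta}\!\left[\frac{\pi_0(x)}{\pi_\theta(x)}\, e^{r(x)/\beta}\right] \\
  &\ge \mathbb{E}_{x \sim \pi_\theta}\!\left[\frac{r(x)}{\beta} - \log\frac{\pi_\theta(x)}{\pi_0(x)}\right] \\
  &= \frac{1}{\beta}\Big(\mathbb{E}_{x \sim \pi_\theta}[r(x)] - \beta\, \mathrm{KL}(\pi_\theta \,\|\, \pi_0)\Big).
\end{aligned}
```

The bracketed term on the last line is the KL-regularized RL objective. The gap in the inequality is exactly $\mathrm{KL}(\pi_\theta \,\|\, \pi^*)$, where $\pi^*(x) = \pi_0(x)\, e^{r(x)/\beta} / p(\mathcal{O}=1)$ is the posterior. Maximizing the objective is therefore the same as pulling $\pi_\theta$ toward $\pi^*$.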

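The KL divergence constraint transforms RL into Bayesian inference, the problem of updating a distribution to conform with new evidence. In our setting, we're updating $\pi_\theta$, which is initially equal to a prior $\pi_0$, to conform with evidence provided by the assumption that $\pi_\theta$ is optimal in terms of reward $r$. The resulting posterior is $\pi^*(x) = \frac{1}{Z}\pi_0(x)\exp(r(x)/\beta)$, where $\beta$ is the strength of the KL penalty and $Z$ is a normalizing constant.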
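The identity behind this view can be checked numerically on a toy example. With posterior $\pi^*(x) \propto \pi_0(x)\exp(r(x)/\beta)$ and normalizer $Z$, the identity is $\mathbb{E}_\pi[r] - \beta\,\mathrm{KL}(\pi \,\|\, \pi_0) = \beta \log Z - \beta\,\mathrm{KL}(\pi \,\|\, \pi^*)$ for any $\pi$. Maximizing the KL-regularized objective is therefore the same as minimizing the KL divergence to $\pi^*$. The prior, rewards, $\beta$ and candidate distribution below are all made up:

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions with q > 0 everywhere."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def posterior(prior, rewards, beta):
    """pi*(x) = pi_0(x) * exp(r(x) / beta) / Z; returns (pi*, Z)."""
    w = [p0 * math.exp(r / beta) for p0, r in zip(prior, rewards)]
    z = sum(w)
    return [wi / z for wi in w], z

def kl_regularized_objective(pi, prior, rewards, beta):
    """J(pi) = E_pi[r] - beta * KL(pi || pi_0)."""
    expected_reward = sum(p * r for p, r in zip(pi, rewards))
    return expected_reward - beta * kl(pi, prior)

prior = [0.5, 0.3, 0.2]    # hypothetical pi_0
rewards = [0.0, 1.0, 2.0]  # hypothetical r
beta = 0.5
pi_star, z = posterior(prior, rewards, beta)

candidate = [0.2, 0.3, 0.5]  # any other distribution pi_theta
lhs = kl_regularized_objective(candidate, prior, rewards, beta)
rhs = beta * math.log(z) - beta * kl(candidate, pi_star)
print(abs(lhs - rhs) < 1e-12)  # True: J = beta*log Z - beta*KL(pi || pi*)
print(kl_regularized_objective(pi_star, prior, rewards, beta) > lhs)  # True: pi* is optimal
```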

However, this posterior is a non-parametric distribution: there is no guarantee that it lies in the parametric family the LM can represent (most arbitrarily chosen probability distributions won't). We would therefore like to find the member of that parameterized family that is most similar to the true posterior. This is the aim of variational inference, a technique used in Bayesian statistics when computing the posterior distribution directly is intractable.
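Below is a minimal sketch of variational inference on a toy problem. A real LM's parameterization is vastly richer, so this is only an illustration. The family is a hypothetical one-parameter exponential family, $q_\theta(x) \propto \exp(\theta \cdot \phi(x))$, which cannot represent the made-up target posterior exactly. Plain gradient descent on the reverse KL, $\mathrm{KL}(q_\theta \,\|\, \pi^*)$, finds its closest member. Minimizing this reverse KL is equivalent to maximizing the KL-regularized RL objective:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def q_theta(theta, features):
    """Hypothetical parametric family: q(x) proportional to exp(theta . phi(x))."""
    return softmax([sum(t * f for t, f in zip(theta, phi)) for phi in features])

def reverse_kl(q, p):
    """KL(q || p) for discrete distributions with p > 0 everywhere."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

def reverse_kl_grad(theta, features, target):
    """Gradient of KL(q_theta || target): E_q[(log q - log p) * (phi - E_q[phi])]."""
    q = q_theta(theta, features)
    dims = range(len(theta))
    mean_phi = [sum(qi * phi[k] for qi, phi in zip(q, features)) for k in dims]
    grad = [0.0 for _ in dims]
    for qi, pi, phi in zip(q, target, features):
        weight = qi * (math.log(qi) - math.log(pi))
        for k in dims:
            grad[k] += weight * (phi[k] - mean_phi[k])
    return grad

# Hypothetical target posterior pi* over four outputs, peaked at index 2.
target = [0.05, 0.15, 0.5, 0.3]
# One feature per output: a one-parameter family cannot reproduce that peak.
features = [[0.0], [1.0], [2.0], [3.0]]

theta = [0.0]
initial_kl = reverse_kl(q_theta(theta, features), target)
for _ in range(3000):
    grad = reverse_kl_grad(theta, features, target)
    theta = [t - 0.1 * g for t, g in zip(theta, grad)]  # plain gradient descent

q = q_theta(theta, features)
final_kl = reverse_kl(q, target)
print([round(x, 3) for x in q], round(initial_kl, 4), round(final_kl, 4))
```

The final KL is smaller than at the start but stays above zero. That nonzero residual is the price of restricting the model to a parametric family.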

Figure from the original paper

The Bayesian perspective unifies approaches to LM alignment

Why is this reframing insightful? The KL penalty term transforms the problem from RL into one of minimizing divergence from a target distribution. The distributional nature of the LM is thus treated as a central element. In this way, the Bayesian perspective enables us to view aligning an LM with human preferences as a two-step process:

  1. Modeling: Defining a distribution specifying the desired behavior of your LM

  2. Inference: Solving the problem of sampling from that posterior

There are various approaches to inference. KL-regularized RL corresponds to variational inference. There are also decoding-time methods, which simulate the aligned posterior LM by modifying the generation procedure applied on top of the original LM $\pi_0$.

An example of a decoding-time inference method given in the paper is filtering (rejection sampling). If the LM generates an unacceptable sample that does not align with the desired behavior, the sample is discarded and a new one is generated. Other oversight approaches can also be seen as decoding-time solutions to inference. RLHF and supervised fine-tuning on labeled data can be seen as variational inference approaches.
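Below is a minimal sketch of rejection sampling as decoding-time inference, using a made-up three-output "LM" and made-up rewards. Proposals come from the unmodified prior $\pi_0$. Each proposal is accepted with probability $\exp((r(x) - r_{\max})/\beta)$, so accepted samples follow the same posterior $\pi^*(x) \propto \pi_0(x)\exp(r(x)/\beta)$ that KL-regularized RL targets. Plain filtering is the special case where acceptable outputs get reward 0 and unacceptable ones are always rejected:

```python
import math
import random

def rejection_sample(sample_prior, reward, beta, r_max, rng):
    """Draw one sample from pi*(x) proportional to pi_0(x) * exp(r(x) / beta).

    r_max must upper-bound the reward so the acceptance probability is <= 1.
    """
    while True:
        x = sample_prior(rng)  # propose from the original LM pi_0
        if rng.random() < math.exp((reward(x) - r_max) / beta):
            return x           # accept; otherwise discard and resample

# Hypothetical toy "LM" over three outputs and a hypothetical reward function.
outputs = ["polite", "neutral", "rude"]
prior_probs = [0.2, 0.5, 0.3]
rewards = {"polite": 1.0, "neutral": 0.0, "rude": -2.0}

def sample_prior(rng):
    return rng.choices(outputs, weights=prior_probs, k=1)[0]

beta = 1.0
rng = random.Random(0)
r_max = max(rewards.values())
samples = [rejection_sample(sample_prior, rewards.get, beta, r_max, rng) for _ in range(20000)]
empirical = {o: samples.count(o) / len(samples) for o in outputs}

# Exact posterior, for comparison with the empirical frequencies.
weights = {o: p * math.exp(rewards[o] / beta) for o, p in zip(outputs, prior_probs)}
z = sum(weights.values())
exact = {o: w / z for o, w in weights.items()}
print(empirical)
print(exact)
```

This approach never updates the LM's weights. The cost is paid at generation time instead, as rejected proposals, and it grows as the posterior moves further from the prior.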
