Planning in LLMs: Insights from AlphaGo

Introduction

Talk of incorporating planning techniques such as Monte Carlo tree search (MCTS) into LLMs has been bubbling around the AI sphere recently, both in relation to Google’s Gemini and OpenAI’s Q*. Much of this discussion has been in the context of AlphaGo, so I decided to go back and read through AlphaGo and some subsequent papers (AlphaGo Zero and AlphaZero). This post highlights what these papers did in the context of LLMs and some thoughts I had while reviewing the papers.

When I say LLMs in this post I am referring to causal/decoder/GPT LLMs.

AlphaGo

Overview

AlphaGo trains two supervised learning (SL) policy networks, a reinforcement learning (RL) policy network, and an SL value network, and uses MCTS for planning. It learns to play the game of Go.

SL Policy Networks

Two SL policy networks are trained:

  1. A slow policy network using a CNN. Used to compute the prior probabilities of state-action pairs when expanding nodes during MCTS.

  2. A fast rollout policy using a linear softmax over hand-crafted features. Used for full game simulations during MCTS rollouts.

Both networks are trained to predict the next move from a data set of expert games. This is the same prediction target as LLMs trained to predict the next token from a data set of internet text: AlphaGo learns a softmax over next moves and LLMs learn a softmax over next tokens.
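Here is a minimal sketch of the shared objective, in plain Python with toy numbers I made up. Both cases apply a softmax over the candidates and then take the negative log-likelihood of the candidate that actually came next. The board size and vocabulary are illustrative only. Real policy networks and LLMs compute the logits with a CNN or a transformer.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def next_item_nll(logits, target):
    """Cross-entropy for one prediction: -log p(target | context).

    The same function scores a next-move prediction (logits over board
    points) or a next-token prediction (logits over a vocabulary).
    """
    return -math.log(softmax(logits)[target])

# Toy next-move prediction: 361 points on a 19x19 board, the expert played 72.
go_logits = [0.0] * 361
go_logits[72] = 3.0
move_loss = next_item_nll(go_logits, 72)

# Toy next-token prediction: a 5-token vocabulary, the next token has id 2.
lm_logits = [0.1, -1.0, 2.5, 0.0, 0.3]
token_loss = next_item_nll(lm_logits, 2)
```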

This use of a fast and slow model reminds me of speculative decoding. I haven’t thought through the implications of this much and it is far from a perfect analogy, but it could be a useful insight.

RL Policy Network

The RL policy network is trained by fine-tuning the SL policy network with REINFORCE on games between the current RL policy network and randomly selected previous RL policy networks. Rewards are +1 for winning and −1 for losing.
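Below is a minimal sketch of a REINFORCE update with AlphaGo's ±1 outcome reward, applied to a toy softmax policy over four moves. The "game" is a hypothetical one-move stand-in in which move 2 always wins. In AlphaGo the policy is a CNN, and the outcome comes from a full game against a randomly selected earlier policy. The update is also summed over every move of that game, whereas here each game is a single move.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def reinforce_step(theta, action, z, lr=0.1):
    """One REINFORCE update for a softmax policy with parameters theta.

    The gradient of log pi(action) w.r.t. theta[k] is 1[k == action] - pi(k);
    it is scaled by the game outcome z (+1 for a win, -1 for a loss).
    """
    probs = softmax(theta)
    return [t + lr * z * ((1.0 if k == action else 0.0) - p)
            for k, (t, p) in enumerate(zip(theta, probs))]

def toy_game(action):
    """Hypothetical stand-in for a full game: move 2 always wins."""
    return 1 if action == 2 else -1

rng = random.Random(0)
theta = [0.0, 0.0, 0.0, 0.0]
for _ in range(500):
    # Sample a move from the current policy, play it, and update on the outcome.
    action = rng.choices(range(len(theta)), weights=softmax(theta))[0]
    theta = reinforce_step(theta, action, toy_game(action))
```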

This RL policy network, using no search, won 85% of games against Pachi, a Go program that used 100,000 MCTS simulations per move. This shows that pure RL can outperform pure search, but usually a combination of the two gives the best performance.

The RL policy network was not actually used in the final version of AlphaGo; it was only used to generate data for training the value network. The authors noted that the SL policy network performed better than the RL policy network “presumably because humans select a diverse beam of promising moves, whereas RL optimizes for the single best move.”

This lack of diversity is reminiscent of the mode collapse phenomenon in certain GPT models fine-tuned on human feedback data. Human selection does not encourage diversity in this case. I think the pre-training data is more varied than the fine-tuning data in both cases, leading to more diverse outputs from the pre-trained network than the fine-tuned one. The lack of diversity could also be caused by value lock-in as mentioned in this comment.

Value Network

The value network is trained with SL to predict the game outcome from positions in self-play games between the RL policy network and itself. The value network outputs a single scalar prediction instead of a probability distribution over moves.

The authors found that training on every position from complete games led to overfitting. Successive positions are strongly correlated, differing by just one stone, yet they all share the same outcome label. To mitigate this, they generated a new data set of 30 million distinct positions, each sampled from a separate game that the RL policy network played against itself until termination. Training on this new data led to minimal overfitting.

Based on this, I am interested to know if data for the reward modeling stage of RLHF consists only of complete conversations, or if subsets of these conversations are used.
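The one-position-per-game fix described above can be sketched as follows. `fake_selfplay_game` is a hypothetical stand-in for a real self-play game that returns (position, outcome) pairs, and the sketch ignores how the games themselves are generated. The key line keeps exactly one example per game.

```python
import itertools
import random

def build_value_dataset(play_selfplay_game, num_games, rng):
    """Collect (position, z) training pairs, keeping one position per game.

    play_selfplay_game(rng) is assumed to return a list of (position, z)
    pairs for one game, where z is the final outcome from the view of the
    player to move in that position.
    """
    dataset = []
    for _ in range(num_games):
        examples = play_selfplay_game(rng)
        dataset.append(rng.choice(examples))  # one example per game, not all of them
    return dataset

_game_ids = itertools.count()

def fake_selfplay_game(rng):
    """Hypothetical game: positions are (game_id, move) tuples and the
    outcome flips sign with the player to move."""
    game_id = next(_game_ids)
    length = rng.randint(50, 300)
    z_first = 1 if rng.random() < 0.5 else -1
    return [((game_id, move), z_first if move % 2 == 0 else -z_first)
            for move in range(length)]

data = build_value_dataset(fake_selfplay_game, num_games=1000, rng=random.Random(0))
```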

MCTS in AlphaGo

I didn’t fully understand how MCTS works in the context of AlphaGo (or at all really) as I was writing this, so this section will be my attempt to explain it in my own words. You can skip this if you already know it.

MCTS consists of 4 stages: selection, expansion, evaluation, and backup.

  1. Selection: traverse the game tree from the root until reaching a leaf node. At each step, choose the edge (action) with the maximum upper confidence bound (UCB), $a_t = \arg\max_a \left( Q(s_t, a) + u(s_t, a) \right)$, where $u(s, a) \propto \frac{P(s, a)}{1 + N(s, a)}$; this gives an exploration bonus to uncertain action-values. I’ll explain $Q$, $N$, and $P$ below.

  2. Expansion: upon reaching a leaf node $s_L$, compute the prior probabilities $P(s_L, a)$ for that node from the SL policy network.

  3. Evaluation: the evaluation of a leaf node is a weighted sum of the value network prediction $v_\theta(s_L)$ for the node and the outcome $z_L$ of a game played from $s_L$ by the fast policy network: $V(s_L) = (1 - \lambda)\, v_\theta(s_L) + \lambda z_L$. The authors found that $\lambda = 0.5$ worked best.

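  4. Backup: update $Q$ and $N$ for each edge traversed during the simulation. $Q(s, a)$ is the average $V(s_L)$ for an edge over all simulations that passed through it: $N(s, a) = \sum_i \mathbb{1}(s, a, i)$ and $Q(s, a) = \frac{1}{N(s, a)} \sum_i \mathbb{1}(s, a, i)\, V(s_L^i)$, where $\mathbb{1}(s, a, i)$ indicates whether edge $(s, a)$ was traversed in simulation $i$.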

This is repeated for some number of simulations. For AlphaGo, it was however many simulations could be completed within 5 seconds per move. They used an asynchronous policy and value MCTS (APV-MCTS) algorithm, which executes simulations in parallel.

Explaining some of the variables from above:

  1. $Q(s, a)$ is the action-value for an edge in the search graph.

  2. $N(s, a)$ is the visit count for an edge in the search graph.

  3. $P(s, a)$ is the SL policy network action probability (the prior) for an edge in the search graph.
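To check my understanding, I wrote a minimal single-threaded sketch of the four stages. It uses a toy take-away game instead of Go: players alternately remove 1 or 2 stones, and whoever takes the last stone wins. `uniform_policy`, `zero_value`, and `random_rollout` are hypothetical stand-ins for the SL policy network, the value network, and the fast rollout policy. The bonus is $u(s, a) = c_{\text{puct}} P(s, a) \sqrt{\sum_b N(s, b)} / (1 + N(s, a))$, which is proportional to $P(s, a) / (1 + N(s, a))$ at a given node; $c_{\text{puct}} = 1$ is my arbitrary choice. Because this is a two-player game, each edge stores values from the view of the player who chose it, so the sign flips at every level during backup. There is no asynchronous parallelism as in APV-MCTS.

```python
import math
import random

# Toy game (hypothetical stand-in for Go): the state is the number of stones
# left, a move removes 1 or 2 stones, and whoever takes the last stone wins.
def legal_moves(stones):
    return [a for a in (1, 2) if a <= stones]

def uniform_policy(stones):
    """Stand-in for the SL policy network: uniform priors P(s, a)."""
    moves = legal_moves(stones)
    return {a: 1.0 / len(moves) for a in moves}

def zero_value(stones):
    """Stand-in for the value network v(s): an uninformative estimate."""
    return 0.0

def random_rollout(stones, rng):
    """Stand-in for the fast rollout policy: random moves to the end.
    Returns +1 if the player to move at the leaf wins, else -1."""
    sign = 1  # +1 while the leaf's player is to move, -1 for the opponent
    while stones > 0:
        stones -= rng.choice(legal_moves(stones))
        sign = -sign
    return -sign  # the player now to move faces an empty pile, so they lost

class Node:
    def __init__(self, stones):
        self.stones = stones
        self.edges = {}  # action -> [P, N, W, child]; Q = W / N
        self.expanded = False

def select_action(node, c_puct):
    """Selection: argmax over Q(s, a) + u(s, a)."""
    total_visits = sum(edge[1] for edge in node.edges.values())
    def score(action):
        prior, visits, total_value, _ = node.edges[action]
        q = total_value / visits if visits else 0.0
        u = c_puct * prior * math.sqrt(total_visits) / (1 + visits)
        return q + u
    return max(node.edges, key=score)

def simulate(root, rng, lam=0.5, c_puct=1.0):
    node, path = root, []
    # 1. Selection: descend until reaching an unexpanded or terminal node.
    while node.expanded and node.stones > 0:
        edge = node.edges[select_action(node, c_puct)]
        path.append(edge)
        node = edge[3]
    if node.stones == 0:
        value = -1.0  # terminal: the player to move has already lost
    else:
        # 2. Expansion: store priors P(s_L, a) on the new edges.
        for action, prior in uniform_policy(node.stones).items():
            node.edges[action] = [prior, 0, 0.0, Node(node.stones - action)]
        node.expanded = True
        # 3. Evaluation: V(s_L) = (1 - lam) * v(s_L) + lam * z_L.
        value = (1 - lam) * zero_value(node.stones) + lam * random_rollout(node.stones, rng)
    # 4. Backup: flip sign so each edge sees the value for the player who chose it.
    for edge in reversed(path):
        value = -value
        edge[1] += 1
        edge[2] += value

def mcts(stones, num_simulations=2000, seed=0):
    rng = random.Random(seed)
    root = Node(stones)
    for _ in range(num_simulations):
        simulate(root, rng)
    # Like AlphaGo, play the most visited move at the root.
    return max(root.edges, key=lambda a: root.edges[a][1])

best_from_five = mcts(5)
best_from_four = mcts(4)
```

Taking 2 stones from a pile of 5 leaves 3, a position the player to move always loses, so the search should settle on taking 2 (and on taking 1 from a pile of 4).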

MCTS and RLHF

RLHF is very similar to the latter stages of AlphaGo. The reward model(s) in RLHF correspond to the value network in AlphaGo, and the human comparison between model outputs in RLHF can be viewed as a special case of MCTS. In RLHF, two (or more) model outputs are shown to a human rater, who ranks them. Each of these outputs can be viewed as a single MCTS simulation. In RLHF, each simulation starts from the root node (the end of the user input), there is no exploration bonus $u(s, a)$, and the evaluation comes entirely from the rater’s judgment of the finished output, i.e. $\lambda = 1$.
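One difference from AlphaGo's value network is that a reward model is usually fit to comparisons rather than regressed onto outcomes. Below is a minimal sketch of the pairwise loss commonly used for this, a Bradley-Terry style objective $-\log \sigma(r_{\text{chosen}} - r_{\text{rejected}})$. Plain floats stand in for reward-model outputs, and the function name is my own.

```python
import math

def pairwise_preference_loss(r_chosen, r_rejected):
    """-log sigmoid(r_chosen - r_rejected) for one human comparison.

    Small when the preferred output gets the higher reward. Written as
    log(1 + exp(-diff)) in a form that does not overflow for large |diff|.
    """
    diff = r_chosen - r_rejected
    return max(-diff, 0.0) + math.log1p(math.exp(-abs(diff)))
```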

The similarities between MCTS and RLHF suggest some improvements to RLHF. The simplest one I could think of is to have the model completions branch at random points instead of branching at the end of the user input. If the AlphaGo value network overfitting from training only on full games carries over to LLMs, it could be mitigated by branching at a random point in the model generation. Another improvement could be to add more branching throughout the model generation, leading to more model outputs to rank. This would be difficult to get human feedback for but could be done using feedback from another AI as in RLAIF.
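Here is a minimal sketch of the random-branch-point idea. Two completions share a prefix whose length is drawn at random, rather than branching immediately after the user input. `branched_pair` and `toy_sampler` are hypothetical names, and completions have a fixed length with no end-of-sequence handling, which keeps the sketch short.

```python
import random

def branched_pair(prompt, sample_next_token, max_new_tokens, rng):
    """Generate two completions that share a randomly chosen prefix.

    sample_next_token(tokens, rng) is a hypothetical sampler returning one
    token. Standard RLHF data branches at len(prompt); here the branch
    point is drawn uniformly from the generated part instead.
    """
    shared = list(prompt)
    for _ in range(rng.randint(0, max_new_tokens - 1)):
        shared.append(sample_next_token(shared, rng))
    completions = []
    for _ in range(2):
        tokens = list(shared)
        while len(tokens) < len(prompt) + max_new_tokens:
            tokens.append(sample_next_token(tokens, rng))
        completions.append(tokens)
    return len(shared), completions[0], completions[1]

def toy_sampler(tokens, rng):
    """Hypothetical stand-in for sampling one token from an LLM."""
    return rng.choice(["a", "b", "c", "d"])

prompt = ["<user>", "hello", "<assistant>"]
branch_point, first, second = branched_pair(prompt, toy_sampler, 8, random.Random(0))
```

The pair `(first, second)` would then go to a rater (human or AI) exactly as a normal comparison does. The only change is where the two samples diverge.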

In this paper, the authors iteratively update the reward models as more data is produced through the model playing with human users. In AlphaGo the value network is fixed, even as more data is produced through the model playing with itself.

Next Token Prediction

Some people say that LLMs are simply predicting the next token. Would they say the same of AlphaGo? Does AlphaGo’s use of inference-time MCTS suddenly make it an agent? Using RL doesn’t suddenly make a policy agentic, so why should MCTS? Even if LLMs and AlphaGo without MCTS are “simply” predicting the next token, this doesn’t mean they aren’t agents or aren’t intelligent. As I mentioned earlier, the AlphaGo RL policy network without search won 85% of games against Pachi, the strongest open-source Go program at the time, which executed 100,000 MCTS simulations per move.

In the same way that RL doesn’t suddenly make a policy agentic, SL doesn’t mean a policy isn’t agentic. Moves in a Go data set and tokens on the internet were created, in sequence, by a human with intent (most of the time). This intent is implicitly inherited by models trained on the data. Consider the Mountain Car environment, where an underpowered car must rock back and forth to climb out of a valley to a goal on the right. A model that only cared about making progress with the next action would only move right and never reach the goal. A model trained with SL on expert data would initially move left, not because it has some plan for the future, but because it learned that left is the action an expert would make. This model wouldn’t be agentic, but I think it is overwhelmingly likely that there exists a (hypothetical) data set that could train a model with SL such that the model would be considered an agent by human standards.
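A small simulation makes the Mountain Car claim concrete. It re-implements the classic dynamics, with constants following the standard Gym MountainCar-v0 formulation as I remember it. `push_right` is the myopic policy. `expert` is a hypothetical expert that pushes in the direction the car is already moving, so from rest its first action is to push left. Neither policy plans; the expert's rule simply encodes the long-term strategy.

```python
import math

def step(position, velocity, action):
    """Classic Mountain Car dynamics: 0 = push left, 1 = no push, 2 = push right."""
    velocity += (action - 1) * 0.001 - 0.0025 * math.cos(3 * position)
    velocity = max(-0.07, min(0.07, velocity))
    position = max(-1.2, min(0.6, position + velocity))
    if position == -1.2 and velocity < 0:
        velocity = 0.0  # the car stops dead at the left wall
    return position, velocity

def reaches_goal(policy, position=-0.5, max_steps=1000):
    """Run a policy from rest and report whether it reaches the goal at 0.5."""
    velocity = 0.0
    for _ in range(max_steps):
        position, velocity = step(position, velocity, policy(position, velocity))
        if position >= 0.5:
            return True
    return False

def push_right(position, velocity):
    """Myopic policy: always push toward the goal."""
    return 2

def expert(position, velocity):
    """Hypothetical expert: push along the current velocity (left from rest)."""
    return 2 if velocity > 0 else 0
```

Pushing right alone never builds enough momentum to leave the valley. The expert's early left pushes are what get it to the goal.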

LLMs simply predicting the next token is discussed further here, with analogies to AlphaGo brought up in these comments. An interesting thought experiment is brought up in this comment. In short, if you ask an LLM to remember a number for the future, does it actually do this, or does it generate a new number when asked what the number was?

My thoughts on this are that LLMs don’t store a number, since that would require them to have a memory, but that doesn’t mean the LLM isn’t considering the future. The LLM’s “plan” will be updated with each token it generates, like how a chess player’s plan will change based on the opponent’s move. The LLM is not playing against an opponent when generating text; in my mind it acts like this improv game. Each person (LLM) says a word (token) with an idea of where the story will go, but each has only a vague idea of the other’s intentions and will continue the story in a slightly different direction. As the softmax temperature increases, predicting the other’s intentions becomes more difficult. This all reminds me of acausal trade.

Overall, I believe that LLMs know a lot more than is implied by their ability to “simply” predict the next token. RLHF reward models are fine-tuned versions of the pre-trained model with the classification head replaced by a regression head. These reward models know a lot more than what the next token should be. I also presume the OpenAI text and code embeddings are an intermediate layer of GPT, or maybe a new head with a small amount of fine-tuning.

AlphaGo Zero

Preface

In the AlphaGo Zero (AGZ) paper it is mentioned that a second, slightly different, version of AlphaGo was created for the match with Lee Sedol. This second version is referred to as AlphaGo Lee, while the original version in the AlphaGo paper is referred to as AlphaGo Fan.

Overview

The next step after AlphaGo was AGZ, which learned the game of Go from scratch. AGZ combined the policy network and value network from AlphaGo by using one network with two heads. It learns by policy iteration: self-play with search is used for policy improvement and game outcome is used for policy evaluation. This policy iteration is similar to Iterated Distillation and Amplification (IDA).
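As I understand the AGZ paper, the two-headed network outputs move probabilities $p$ and a value $v$. It is trained on self-play tuples $(s, \pi, z)$, where $\pi$ comes from MCTS visit counts (policy improvement) and $z$ is the game outcome (policy evaluation). The loss is $(z - v)^2 - \pi^\top \log p + c\lVert\theta\rVert^2$. Below is a minimal sketch of the search target and the loss for a single position. Plain lists stand in for network outputs, and the default `c` and toy numbers are my own choices.

```python
import math

def search_policy(visit_counts, tau=1.0):
    """MCTS policy target: pi(a) proportional to N(s, a) ** (1 / tau)."""
    powered = [n ** (1.0 / tau) for n in visit_counts]
    total = sum(powered)
    return [x / total for x in powered]

def agz_loss(z, v, pi, p, params=(), c=1e-4):
    """(z - v)^2 - pi . log p + c * ||theta||^2 for a single position.

    z: self-play game outcome (+1 or -1); v: value head output;
    pi: search policy target; p: policy head probabilities;
    params: flat sequence of network weights for the L2 penalty.
    """
    value_loss = (z - v) ** 2
    policy_loss = -sum(t * math.log(q) for t, q in zip(pi, p) if t > 0)
    l2 = c * sum(w * w for w in params)
    return value_loss + policy_loss + l2
```

A lower temperature `tau` sharpens the target toward the most visited move.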

Iterated Distillation and Amplification

In this post, Paul Christiano talks about how AGZ is a “nice proof of concept of a promising alignment strategy.” This alignment strategy, benign model-free RL, is what (I think) eventually came to be known as IDA. The way RLHF is used for LLM training in this paper is also similar to IDA. Explaining IDA in terms of [ AGZ | LLM RLHF ]: a slow model [ MCTS | human ] is used to train a better fast model [ policy network & value network | LLM & reward model ]. The better fast model is then used to improve the slow model, the better slow model trains a better fast model, and so on.

Language Modeling as a Markov Decision Process

This blog post is linked in Paul’s post on AGZ. One thing this post brings up is modeling conversation as a Markov decision process (MDP), more specifically a partially observable MDP (POMDP). The author suggests the state be some hand-crafted features and the actions be full dialogue turns chosen from a pre-determined set of monologues. This made sense at the time (February 2017), but with LLMs the MDP can be constructed at the token level.

This paper views language modeling as a POMDP, with actions as the possible set of next tokens and observation as a history of tokens. This paper views goal-directed dialogue as an MDP, with the initial state as some task-specific prompt, actions as next tokens, next state as the previous state with the action appended to the end, and reward based on the final state and some target string.
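Here is a minimal sketch of that goal-directed dialogue MDP at the token level. The class name, the fixed `max_len` termination rule, and the exact-suffix-match reward are my own illustrative assumptions, not the paper's definitions.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TokenMDP:
    """Toy token-level MDP: state = token tuple, action = next token.

    The transition appends the action to the state. The reward is given only
    at the final state: a hypothetical 1.0 if the state ends with the
    (non-empty) target string, else 0.0.
    """
    prompt: tuple
    target: tuple
    max_len: int

    def initial_state(self):
        return self.prompt

    def transition(self, state, action):
        return state + (action,)

    def is_terminal(self, state):
        return len(state) >= self.max_len

    def reward(self, state):
        if not self.is_terminal(state):
            return 0.0
        return 1.0 if state[-len(self.target):] == self.target else 0.0

mdp = TokenMDP(prompt=("Q:", "2+2=", "A:"), target=("4",), max_len=4)
final_state = mdp.transition(mdp.initial_state(), "4")
```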

LLMs directly predict action (token) distributions from the token history; they don’t explicitly predict the hidden state from the observation (token history). Despite this, I think it is likely that LLMs implicitly predict the hidden state: somewhere within the transformer the computation transitions from mostly state prediction to mostly action prediction. Some evidence of this is that Othello-GPT has a world representation.

Thinking about language modeling in terms of a POMDP enables a more structured way of thinking about LLMs. For example,

  1. Why are many observations (token histories) by the LLM grouped together as simulators? Do they all have similar hidden states in the POMDP as predicted by the LLM’s implicit world model? Which (if any) of these observations are more agentic than others?

  2. What is happening in the observation (token history) of the POMDP that causes LLMs to collapse into a Waluigi?

There are two other subtypes of MDPs that I think are important to consider.

  1. This paper views human dialogue as a hidden parameter MDP, which could also be a potential way of thinking about simulators.

  2. Tokenization can also be viewed as creating options (i.e. temporally extended actions) over the action space of some “alphabet” (e.g. UTF-8). An MDP with options is called a semi-MDP. In this paper, options are created using BPE and used for more efficient sparse-reward learning in a few RL environments.
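Below is a minimal sketch of how BPE builds these options. Each learned merge becomes a token that, when "executed," emits a fixed sequence of primitive symbols. For readability it works on characters rather than UTF-8 bytes and breaks frequency ties deterministically. It is not necessarily the exact procedure used in the paper.

```python
from collections import Counter

def learn_bpe(words, num_merges):
    """Learn BPE merges over characters; return the merges and the
    resulting segmentation of each word."""
    seqs = [list(w) for w in words]
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for seq in seqs:
            pairs.update(zip(seq, seq[1:]))
        if not pairs:
            break
        # Most frequent adjacent pair; ties broken by the pair itself.
        (a, b), _ = max(pairs.items(), key=lambda kv: (kv[1], kv[0]))
        merges.append((a, b))
        for i, seq in enumerate(seqs):
            out, j = [], 0
            while j < len(seq):
                if j + 1 < len(seq) and seq[j] == a and seq[j + 1] == b:
                    out.append(a + b)  # new option covering two primitives
                    j += 2
                else:
                    out.append(seq[j])
                    j += 1
            seqs[i] = out
    return merges, seqs

def execute_option(token):
    """An option's primitive action sequence: its characters, in order."""
    return list(token)

merges, segmented = learn_bpe(["low", "lower", "lowest", "low"], num_merges=2)
```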

Learning from Scratch

The other change from AlphaGo to AGZ, besides IDA, was learning the game from scratch. Instead of being pre-trained on expert games and then further learning through RL self-play, AGZ learns exclusively through RL self-play. AGZ presumably sees more varied positions than AlphaGo, since early parts of its search tree are not heavily biased by pre-training on expert data.

This learning from scratch would be extremely difficult for training LLMs. Even if there were enough annotators and time to do RLHF from scratch, ranking the gibberish strings of tokens produced at the start of training would be impossible.

One way learning from scratch could possibly be done is through RLAIF. I’m not sure if the oversight LLM would be able to meaningfully rank the gibberish strings during early training. If not, an alternative could be to start with a small context length and as the model learns to produce coherent strings the context length can be increased. This could be augmented with methods like Pretraining Language Models with Human Preferences to encourage the model to be aligned with human values.

Exploration and Agency

This section is a bit out there, and I’m a lot less sure of what I’m saying here than I am in the rest of this post.

A difference between RLHF/RLAIF and pre-training is that the RL methods can explore new LLM generations. I believe that exploration is a big reason why RL policies are more likely to be agentic than SL policies.

It makes intuitive sense to me that any offline RL data set can be converted to an SL data set by using the offline RL data to compute the best action distribution and/or value for each state and training the SL policy to mimic this. Additionally, any stationary MDP can be converted into an offline RL data set through infinite online exploration of the MDP. From this, the only real difference between stationary online RL and SL is that the RL policy must efficiently explore its environment to gain data and decide which data points to learn from.
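Here is one way (among several) to do that conversion, sketched for a small tabular problem. It runs fitted Q iteration on the logged transitions only, then labels each state with its greedy action, so that an SL policy could be trained to imitate the labels. The toy chain data and the function name are my own. Real offline RL needs function approximation and care with actions the data never covers.

```python
from collections import defaultdict

def offline_to_supervised(transitions, gamma=0.9, iters=100):
    """Turn offline RL transitions into {state: best action} SL labels.

    transitions: list of (s, a, r, s_next, done) tuples. Only (s, a) pairs
    that appear in the data receive Q-values.
    """
    by_sa = defaultdict(list)
    for s, a, r, s_next, done in transitions:
        by_sa[(s, a)].append((r, s_next, done))
    q = {sa: 0.0 for sa in by_sa}

    def best_value(s):
        values = [v for (s2, _), v in q.items() if s2 == s]
        return max(values) if values else 0.0

    for _ in range(iters):
        # Synchronous Bellman backup averaged over the logged outcomes.
        q = {sa: sum(r + (0.0 if done else gamma * best_value(s2))
                     for r, s2, done in outcomes) / len(outcomes)
             for sa, outcomes in by_sa.items()}

    labels = {}
    for (s, a), v in q.items():
        if s not in labels or v > q[(s, labels[s])]:
            labels[s] = a
    return labels

# Toy chain: moving right from state 1 reaches the goal (reward 1).
logged = [(0, "right", 0.0, 1, False), (1, "right", 1.0, 2, True),
          (0, "left", 0.0, 0, False), (1, "left", 0.0, 0, False)]
sl_labels = offline_to_supervised(logged)
```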

Another way to arrive at this conclusion is to consider an RL policy operating in an online environment. The policy collects data and eventually performs a batch update to itself. This update could be reproduced exactly by some SL update. One difference is that the online RL policy is continually collecting new data by exploring its environment, while an SL data set is pre-determined. The other difference is that batch RL updates are usually taken from recent data, while batch SL updates are usually sampled equally from all data. Therefore, the “agency” of RL comes from gaining new data through exploration and choosing what data to learn from.

As an example of this, the value network learned in AlphaGo is trained using SL on the outcomes of games between the RL policy network and itself. In fact, AlphaGo Fan doesn’t use any RL in the final program: the value network was trained with SL, and the SL policy network was trained with SL. The only use of RL was in training the RL policy network, which was only used to generate a SL data set used to train the value network.

This comment talks about how RL produces inexact gradients, while SL produces exact gradients. While the inexact gradients make RL less sample efficient, they also encourage exploration, since (in expectation) more gradient updates are needed to reach an “optimal” policy. The comment, along with this paper mentioned in the comment, also mentions how data in SL is IID, while data in RL is not, since an RL policy influences its own future.

When drafting this post I wrote that this isn’t a relevant difference when it comes to LLMs, since LLMs also influence their own future. I realized this was incorrect when re-reading it before publishing: LLMs do not influence their own future during training, only during inference. I’m not sure of the implications of IID data in SL and non-IID data in RL on agency; I’ll have to think about it more. Takes are welcome in the comments.

AlphaGo, AGZ, and LLMs can all modulate exploration through their softmax temperature. AlphaGo and AGZ can further encourage exploration by scaling the exploration bonus $u(s, a)$ in the selection stage of MCTS. Exploration in AlphaGo and LLMs is biased towards the pre-training data. More varied RLHF data could be collected by increasing the softmax temperature, but more variance will require more data to train a sufficiently accurate reward model.
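Below is a minimal sketch of temperature scaling. Dividing the logits by a temperature $T$ before the softmax makes the distribution flatter as $T$ grows, with higher entropy and more exploration, and sharper as $T$ shrinks. The logits are made up.

```python
import math

def softmax_with_temperature(logits, temperature):
    """p_i proportional to exp(logit_i / T)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means more diverse samples."""
    return -sum(p * math.log(p) for p in probs if p > 0)

logits = [2.0, 1.0, 0.5, -1.0]
entropies = {t: entropy(softmax_with_temperature(logits, t)) for t in (0.5, 1.0, 2.0)}
```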

As a final note on MDPs, I believe that the MDPs of human language and values are non-stationary; the meanings of words and what we consider moral change with time. There is also the concept of ergodicity. An MDP is ergodic if each state is reachable from every other state. Too much exploration can result in an agent reaching a state where it is cut off from the rest of the MDP (e.g. death). If there is only one agent, this is very bad. In populations of agents (e.g. evolution), this is less of a worry, as the surviving agents will adapt to avoid bad states.
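Under the loose definition above (every state reachable from every other state), checking ergodicity of a small MDP amounts to checking that its transition graph is strongly connected. The graph has an edge wherever some action can move between two states. Below is a minimal sketch with hypothetical state names, where an absorbing "dead" state shows how a single cut-off state breaks the property.

```python
from collections import deque

def reachable(graph, start):
    """All states reachable from start by breadth-first search."""
    seen, queue = {start}, deque([start])
    while queue:
        s = queue.popleft()
        for t in graph.get(s, ()):
            if t not in seen:
                seen.add(t)
                queue.append(t)
    return seen

def is_ergodic(graph):
    """True if every state can reach every other state."""
    states = set(graph) | {t for succ in graph.values() for t in succ}
    return all(reachable(graph, s) >= states for s in states)

alive = {"home": ["forage"], "forage": ["home", "explore"], "explore": ["home"]}
risky = {"home": ["forage"], "forage": ["home", "explore"],
         "explore": ["home", "dead"], "dead": ["dead"]}
```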

AlphaZero

Changes from AGZ to AlphaZero are smaller; most changes are to allow AlphaZero to learn Go, chess, or shogi. The changes outlined in the paper are:

  1. AlphaZero’s value network estimates and optimizes the expected outcome, while AGZ’s estimated and optimized the probability of winning. This change was made because chess and shogi can end in draws.

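  2. AlphaZero does not augment the training data with board symmetries or randomly transform the board position during MCTS, as AGZ did with the eight rotations and reflections of the Go board. Chess and shogi are not symmetric.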

  3. AlphaZero uses a continually updated policy network for self-play, while AGZ used the best policy network from all previous iterations for self-play.

  4. AlphaZero uses the same hyperparameters for all games, except for a scaling factor on policy noise to encourage exploration. AGZ used Bayesian optimization to tune hyperparameters.
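For context on item 2, here is a minimal sketch of the eight board symmetries that make Go augmentation possible: four rotations, each with and without a mirror. Boards are plain lists of rows, and the function names are my own. In practice the move-probability target has to be transformed in the same way as the board.

```python
def rotate(board):
    """Rotate a square board (list of rows) 90 degrees clockwise."""
    return [list(row) for row in zip(*board[::-1])]

def reflect(board):
    """Mirror a board left-to-right."""
    return [row[::-1] for row in board]

def dihedral_augment(board):
    """All 8 rotations and reflections of a square board."""
    views, b = [], board
    for _ in range(4):
        views.append(b)
        views.append(reflect(b))
        b = rotate(b)
    return views

views = dihedral_augment([[1, 2], [3, 4]])
```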

The only change in AlphaZero that I think is worth commenting on is using the latest policy network for self-play, rather than the “best” policy network. I assume this would lead to slightly more varied games, as the policy for choosing moves is always changing instead of possibly being the same for many rounds of games.

Further Improvements to AlphaZero

New algorithms based on AlphaZero have been created since, including MuZero, EfficientZero, and this work on diversifying AlphaZero. I plan to make a post about these variants at some point in the future.