Mode collapse in RL may be fueled by the update equation

TL;DR: We present an advantage variant which, in certain settings, does not train an optimal policy, but instead uses a fixed reward to update a policy a fixed amount from initialization. Non-tabular empirical results seem mixed: The policy doesn’t mode-collapse, but has unclear convergence properties.

Summary: Many policy gradient methods allow a network to extract arbitrarily many policy updates from a single kind of reinforcement event (e.g. for outputting tokens related to weddings). Alex proposes a slight modification to the advantage equation, called “action-conditioned TD error” (ACTDE). ACTDE ensures that the network doesn’t converge to an “optimal” policy (these almost always put infinite logits on a single action). Instead, ACTDE updates the network by a fixed number of logits.

For example, suppose that eating cookies earns 1 more reward than eating pizza ($r(\text{cookies}) = r(\text{pizza}) + 1$). In this case, PPO converges to a policy which puts arbitrarily many logits on cookies, even though the reward difference is small. By contrast, under ACTDE, the network converges to the softmax-over-reward policy {pizza: 27%, cookies: 73%}, which seems more reasonable.

Then, Michael Einhorn shares initial results which support Alex’s theoretical predictions. Using an architecture and Q-head loss function similar to ILQL’s, Michael trained a small transformer on a prisoner’s dilemma and collected initial data on ACTDE. Unlike PPO, ACTDE-trained policies did not mode collapse onto a single action and instead learned mixed strategies.

We’re interested in additional experiments on ACTDE. We hope that, by using ACTDE instead of advantage, we can automatically mitigate “reward specification” issues and maybe even reduce the need for a KL penalty term. That would make it easier to shape policies which do what we want.

The advantage equation implies arbitrary amounts of update on a single experience

In PPO, the optimization objective is proportional to the advantage given a policy $\pi$, reward function $r$, and on-policy value function $v^\pi$:[1]

$$A^\pi(s, a) := r(s, a) + \gamma v^\pi(s') - v^\pi(s).$$

Alex thinks this equation is actually pretty messed up, although it looked decent at first. The problem is that this advantage can oscillate forever. To explain, let’s consider a simple bandit problem: one state (“We had a”) and two actions (“wedding” and “party”) with rewards $r(\text{We had a wedding}) = 1$ and $r(\text{We had a party}) = .5$.

The failure which happens is:

  1. The policy tries out the “wedding” action, receives strong reinforcement of $r = 1$, and increases logits on that action because its advantage was positive. The policy learns that its value is high ($v^\pi(s) = 1$).

  2. The policy eventually tries out the “party” action, receiving less reinforcement at $r = .5$ and decreasing the logits on “party” (because its advantage was negative). The policy learns that the original state’s value is low ($v^\pi(s) = .5$).

  3. The policy tries out “wedding” again and receives positive advantage relative to the low original state value. The logits go up on “wedding”, and the value is once again high ($v^\pi(s) = 1$).

This continues to happen, which means that “wedding” gets arbitrarily high logits.

This flaw is easiest to see formally. Initialize the tabular value function $v^\pi$ to 0, and the policy $\pi$ to be 50/50 for “party”/“wedding”. Let $\alpha = 1$, and update the value function using tabular TD learning with learning rate $\alpha$. So, for example, if the system takes the “wedding” action, its new value function is $v^\pi(s) = 1$. If the system then takes the “party” action, the value snaps back to $v^\pi(s) = .5$.[2]

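The policy update rule is: if the advantage $A^\pi(s, a) = x$, then we add $x$ to $\pi$’s logit on action $a$, so the odds of $a$ relative to the other action are multiplied by $e^x$. So, if $\pi(s, \text{wedding}) = .5$ and advantage $A^\pi(s, \text{wedding}) = 1$, then $\pi'(s, \text{wedding}) = \frac{e}{1 + e} \approx .73$.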

Episode-by-episode:

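| Episode | Action taken | Advantage | New $\pi(\text{wedding})$ | New $v^\pi(s)$ |
|---|---|---|---|---|
| 1 | wedding | 1 | .73 | 1 |
| 2 | party | −.5 | .82 | .5 |
| 3 | party | 0 | .82 | .5 |
| 4 | wedding | .5 | .88 | 1 |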
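The four episodes above can be replayed in a few lines. This is a minimal sketch, not the authors’ code: it assumes the bandit state is followed by a terminal state (so $\gamma v^\pi(s') = 0$), uses tabular TD learning with $\alpha = 1$, and implements the update rule as “add the advantage to the taken action’s logit”. The function and variable names are hypothetical.

```python
import math

# Rewards from the bandit example above.
REWARD = {"wedding": 1.0, "party": 0.5}


def p_wedding(logits: dict[str, float]) -> float:
    """Two-action softmax, written as a sigmoid of the logit difference."""
    return 1.0 / (1.0 + math.exp(logits["party"] - logits["wedding"]))


def replay_td_advantage(actions: list[str], alpha: float = 1.0) -> list[tuple]:
    """Replay a fixed action sequence under the TD-error advantage update.

    Assumption: the next state is terminal, so the advantage is r(s, a) - v(s).
    """
    logits = {"wedding": 0.0, "party": 0.0}  # 50/50 initial policy
    v = 0.0                                   # tabular v^pi(s), initialized to 0
    rows = []
    for a in actions:
        advantage = REWARD[a] - v        # baseline ignores which action was taken
        logits[a] += advantage           # add the advantage to the taken action's logit
        v += alpha * (REWARD[a] - v)     # TD update; alpha = 1 snaps v to the last reward
        rows.append((a, advantage, round(p_wedding(logits), 2), v))
    return rows


for row in replay_td_advantage(["wedding", "party", "party", "wedding"]):
    print(row)  # (action, advantage, new pi(wedding), new v(s))
```

Each switch between actions widens the logit gap by .5 in favor of “wedding”, while repeating an action leaves it unchanged, which is the oscillation described above.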

With probability 1, as the number of episodes $t \to \infty$, $\pi_t(\text{wedding}) \to 1$. You might think this is good, since “wedding” is in fact “optimal” at that state. This does not seem good. Here are a few kinds of explanations for why:

  1. Reward chisels circuits into policy networks. Here, the network can get arbitrarily many policy gradients towards “wedding.” Its logits just go up and up and up. Mode collapse.

  2. We want the reward to be feedback about what kinds of completions are good (or, more abstractly, about what kinds of computations are good to run inside the network). We want a single situation to provide a finite amount of updating.

  3. The system can get stuck in a local attractor. Imagine that we want the system to talk about parties at Chuck-E-Cheese in particular, and we give the system 2 reward if it says “We had a party at Chuck-E-Cheese.” But the system may never get there during training due to exploration issues, which are exacerbated by the network getting penalized relative to its on-policy value estimate $v^\pi(s)$.

    1. In other words, PPO actively updates against actions which aren’t known to beat the current on-policy value $v^\pi(s)$. The process penalizes exploration.

This problem doesn’t seem limited to tabular TD learning, or to PPO in more realistic domains. For example, vanilla policy gradient will also allow a system to extract an unbounded amount of reinforcement from a single kind of event (e.g. “wedding”). Unless very specific care is taken, Alex thinks this kind of failure happens by default in policy gradient methods.
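To see the “with probability 1” claim without hand-picking the action sequence, one can sample actions from the current policy. The sketch below keeps the same assumptions as the tabular example (terminal next state, $\alpha = 1$, advantage added to the taken action’s logit); the seed and episode counts are arbitrary illustrative choices, and the names are hypothetical.

```python
import math
import random

REWARD = {"wedding": 1.0, "party": 0.5}


def p_wedding(logits: dict[str, float]) -> float:
    return 1.0 / (1.0 + math.exp(logits["party"] - logits["wedding"]))


def sample_td_advantage(episodes: int, seed: int = 0, alpha: float = 1.0) -> dict[str, float]:
    """Sample on-policy actions and apply the TD-error advantage update.

    Assumption: one-step bandit with a terminal next state.
    """
    rng = random.Random(seed)
    logits = {"wedding": 0.0, "party": 0.0}
    v = 0.0
    for _ in range(episodes):
        a = "wedding" if rng.random() < p_wedding(logits) else "party"
        advantage = REWARD[a] - v
        logits[a] += advantage
        v += alpha * (REWARD[a] - v)
    return logits


for n in (100, 1_000, 10_000):
    logits = sample_td_advantage(n)
    gap = logits["wedding"] - logits["party"]
    print(f"{n:>6} episodes: logit gap {gap:.1f}, pi(wedding) {p_wedding(logits):.4f}")
```

After the first episode, the gap never shrinks: it just grows more slowly as “party” becomes rarer, so the policy keeps drifting toward putting all of its mass on “wedding”.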

Action-conditioned TD error avoids arbitrarily high logits

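Given the original advantage equation:

$$A^\pi(s, a) := r(s, a) + \gamma v^\pi(s') - v^\pi(s),$$

replace the last term’s baseline $v^\pi(s)$ with $Q^\pi(s, a)$ to account for the taken action:

$$r(s, a) + \gamma v^\pi(s') - Q^\pi(s, a).$$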

We call this “action-conditioned TD error” (ACTDE).

ACTDE allows the system to account for its decision to go off-policy by selecting a new action $a$ which isn’t the usual recommendation $a' \sim \pi(s)$. Philosophically, Alex wanted to mimic reward prediction error. The network taking a different action is not surprising to the network, so the optimization term should account for the action taken (i.e. by using $Q^\pi(s, a)$).

Re-analyzing the situation:

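| Episode | Action | Action-conditioned TD error | New $\pi(\text{wedding})$ | New $Q^\pi(s, \cdot)$ |
|---|---|---|---|---|
| 1 | wedding | 1 | .73 | wedding: 1, party: 0 |
| 2 | party | .5 | .62 | wedding: 1, party: .5 |
| 3 | party | 0 | .62 | wedding: 1, party: .5 |
| 4 | wedding | 0 | .62 | wedding: 1, party: .5 |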
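The same replay with the action-conditioned baseline, again as a minimal sketch under the assumptions used for the tabular example (terminal next state, $\alpha = 1$), with a tabular $Q^\pi(s, \cdot)$ initialized to 0. Names are hypothetical.

```python
import math

REWARD = {"wedding": 1.0, "party": 0.5}


def p_wedding(logits: dict[str, float]) -> float:
    return 1.0 / (1.0 + math.exp(logits["party"] - logits["wedding"]))


def replay_actde(actions: list[str], alpha: float = 1.0) -> list[tuple]:
    """Replay a fixed action sequence under the ACTDE update.

    Assumption: the next state is terminal, so ACTDE is r(s, a) - Q(s, a).
    """
    logits = {"wedding": 0.0, "party": 0.0}
    q = {"wedding": 0.0, "party": 0.0}  # tabular Q^pi(s, .), initialized to 0
    rows = []
    for a in actions:
        td_error = REWARD[a] - q[a]      # baseline is the taken action's own Q-value
        logits[a] += td_error
        q[a] += alpha * td_error         # TD update on Q; alpha = 1 sets Q(s, a) = r(s, a)
        rows.append((a, td_error, round(p_wedding(logits), 2), dict(q)))
    return rows


for row in replay_actde(["wedding", "party", "party", "wedding"]):
    print(row)  # (action, ACTDE, new pi(wedding), new Q(s, .))
```

Once each action has been tried, its Q-value equals its reward and its TD error is zero, so the logits stop moving.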

The policy quickly converges to the softmax logits over the reward for the next completions, where $\pi(\text{wedding}) = \frac{e^1}{e^1 + e^{.5}} \approx .62$. That is, the learned policy has $r(\text{party}) = .5$ logits on “party” and $r(\text{wedding}) = 1$ logit on “wedding”. Therefore this process does not converge to the optimal policy, even in the limit of infinite exploration. Correspondingly, there is no mode collapse in this situation. Reward logits are “added to” initialization logits (the prior over what completions to output). RL, in this setting, provides a finite amount of reinforcement for certain kinds of computation/​actions.
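The claim that reward logits are “added to” initialization logits can be checked with on-policy sampling from a non-uniform starting policy. This sketch assumes $\alpha = 1$, a terminal next state, and $Q^\pi$ initialized to 0; under those assumptions the fixed point should be the softmax of (initial logits + reward). The 90/10 initialization and the second example (cookies earning 1 more reward than pizza, as in the summary) are illustrative choices, and the names are hypothetical.

```python
import math
import random


def softmax(logits: dict[str, float]) -> dict[str, float]:
    m = max(logits.values())  # subtract the max for numerical stability
    exps = {a: math.exp(x - m) for a, x in logits.items()}
    total = sum(exps.values())
    return {a: e / total for a, e in exps.items()}


def sample_actde(rewards, init_logits, episodes=1_000, seed=0, alpha=1.0):
    """On-policy tabular ACTDE in a one-step bandit (assumed terminal next state)."""
    rng = random.Random(seed)
    actions = list(rewards)
    logits = dict(init_logits)
    q = {a: 0.0 for a in actions}
    for _ in range(episodes):
        probs = softmax(logits)
        a = rng.choices(actions, weights=[probs[b] for b in actions])[0]
        td_error = rewards[a] - q[a]
        logits[a] += td_error
        q[a] += alpha * td_error
    return softmax(logits)


rewards = {"wedding": 1.0, "party": 0.5}
init = {"wedding": math.log(0.9), "party": math.log(0.1)}
learned = sample_actde(rewards, init)
predicted = softmax({a: init[a] + rewards[a] for a in rewards})
print(learned, predicted)

# Summary example: cookies earn 1 more reward than pizza, uniform initial policy.
print(sample_actde({"pizza": 0.0, "cookies": 1.0}, {"pizza": 0.0, "cookies": 0.0}))
```

With $\alpha = 1$, each action’s logit moves exactly once, by its reward, the first time it is sampled; after that its TD error is zero regardless of how often it is chosen.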

Furthermore, self-consistent, Bellman-backed-up Q-functions will have zero advantage and zero updates. Networks aren’t penalized for exploring, and there’s a precise and finite amount of reinforcement which can occur given current predictions about future value, as represented by $v^\pi(s')$. And training should be more stable, with fewer fluctuations in advantage with respect to the policy itself.[3]

ACTDE doesn’t mode-collapse onto wireheading

ACTDE doesn’t mode collapse on wireheading, even given that the network tries out wireheading! (Which Alex thinks is not that likely for practical RL algorithms.)

Concretely, suppose that reward is 10 if you eat pizza and 100 if you wirehead. You start off with action distribution {pizza: 1%, wirehead: 99%}, and we’re doing TD-learning in the tabular setup we just described. If so, then the policy gradients upweight wireheading more and more. This can happen until the network puts arbitrarily many logits on the wireheading action. In this situation, under these exploration assumptions and with probability 1, PPO “selects for” wireheading and the policy ends up {pizza: $\epsilon$, wirehead: $1 - \epsilon$} for an arbitrarily small $\epsilon$.

However, ACTDE does not lead to arbitrarily many logits on wireheading. Instead, ACTDE leads to the softmax distribution over actions, with the softmax taken over the reward for each action. Thus, the “optimum”/​fixed-point policy of tabular ACTDE is about { pizza: .02%, wirehead: 99.98% }. That’s still mostly wireheading, but there are only finitely many logits on that action.

PPO vs ACTDE on the iterated prisoner’s dilemma

In this toy experiment, the model plays prisoner’s dilemmas against its past self, similar to the idea by Krueger et al. The model is mingpt with a vocab size of two: one token for “cooperate”, and one for “defect”. mingpt has 3 layers and an embedding dimension of 12. The model sees the history of cooperates and defections, and outputs the next action.

We are not training via self-play against a copy. Instead, the model at time $t$ plays against its own action at time $t - 1$. Playing with its past self, a sequence of ccddc has 4 games: cc, cd, dd, and dc, with rewards of 0.5 (for cc), 2 (for cd), −0.74 (for dd), and −1.76 (for dc).[4]

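| Reward matrix (rows: move at $t$; columns: past self’s move at $t-1$) | Cooperate (c) | Defect (d) |
|---|---|---|
| Cooperate (c) | 0.5 | −1.76 |
| Defect (d) | 2 | −0.74 |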

Alternating cooperation (c) and defection (d) is the optimal strategy for both start states (cdcd... when starting from cooperate, dcdc... when starting from defect):

| Action sequence | Sum of discounted reward ($\gamma = 0.5$) |
|---|---|
| cccc... | 1 |
| cddd... | 1.261 |
| **cdcd...** | **1.492** |
| dddd... | −1.477 |
| dccc... | −1.261 |
| **dcdc...** | **−1.015** |
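The game decomposition and the discounted returns in the table can be recomputed from the rounded rewards quoted above. This is a sketch under two assumptions: each game string pairs the past self’s move with the current move (as in the ccddc example), and the infinite strategies are truncated at 100 moves, which changes the returns by far less than the rounding does. Because the quoted rewards are rounded, the third decimal can differ slightly from the table. The names are hypothetical.

```python
GAMMA = 0.5
# Reward for each game, keyed by (move at t-1)(move at t), as quoted in the text.
REWARD = {"cc": 0.5, "cd": 2.0, "dd": -0.74, "dc": -1.76}


def games(seq: str) -> list[str]:
    """Split a move sequence into consecutive games against the past self."""
    return [prev + cur for prev, cur in zip(seq, seq[1:])]


def discounted_return(seq: str, gamma: float = GAMMA) -> float:
    return sum(gamma**t * REWARD[g] for t, g in enumerate(games(seq)))


N = 100  # truncation length for the infinite strategies (assumption)
STRATEGIES = {
    "cccc...": "c" * N,
    "cddd...": "c" + "d" * (N - 1),
    "cdcd...": "cd" * (N // 2),
    "dddd...": "d" * N,
    "dccc...": "d" + "c" * (N - 1),
    "dcdc...": "dc" * (N // 2),
}

print(games("ccddc"))  # ['cc', 'cd', 'dd', 'dc']
for name, seq in STRATEGIES.items():
    print(f"{name}  {discounted_return(seq):.3f}")
```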

What we’re testing: If ACTDE mode collapses when used on function approximators (like mingpt), then the theoretical predictions above are wrong.

PPO results

PPO immediately learns the alternating strategy:

The softmax probability of strategy is basically .[5] The results look almost identical between runs, with or without whitening the advantages, and if the value head is detached.

ACTDE results

The model does not collapse onto a pure strategy. Instead, the results are inconsistent across trials. However, ACTDE does reliably:

  1. initially alternate with high probability

  2. tend to regress towards a uniform (or softmax-return) policy over time,

  3. with .

Here’s the first 1K epochs of a training run:

Note that we aren’t whitening the rewards.

Zooming out to all 10K epochs:

We ran 10 trials and plotted the mean and standard deviation of average returns:

There seems to be very slow convergence,[6] perhaps towards the softmax-over-returns policy (shown by the dotted lines), or towards the uniform policy. We lean towards “convergence to uniform” due to evidence from a trial on a different reward matrix:

Overall, ACTDE’s results are sensitive to variations in the algorithm such as whitening advantages, detaching the value and Q-heads, and using the loss function from PPO or ILQL for the value head.

Speculation

This method might not work very well for e.g. RLHF at scale. Deep RL is notoriously finicky. Furthermore, it would be pretty disappointing if ACTDE generally converges on uniform policies, and that seems like a live possibility given the last graph above.

However, Alex has a few intuitions anyways:

  1. Mitigating reward misspecification. If the reward is really high in a situation which the policy can easily explore into during training, then that’s bad and probably leads to distorted policies.

    1. However, under ACTDE, a single reward event (e.g. hypothetically, a maximal reward whenever “weddings” appears) should have less impact on the trained policy.

    2. For example, the Q-function can quickly learn to predict that a given string “I went to the barn” produces high reward, and then there isn’t any more barn-related reinforcement.

    3. However, if barn-related generations are strongly rewarded in general, the model might still receive reinforcement for the hard-to-predict and diverse range of barn reward events.

  2. Reducing mode collapse. In the tabular bandit regime (explored above), ACTDE adds “reward logits” ($r(s, a)$) onto the initial policy’s logits. Maybe this is true in general (but probably not). If so, then KL penalties might be less important for keeping the trained policy $\pi$ close to the initial policy $\pi_0$.

    1. Less mode collapse means higher-entropy next-token distributions, which may mean greater variety in the policy’s preferences/​shards. That is, it may be rarer for motivational/​goal-encoding circuits to be effectively pruned by mode-collapsed RLHF/​RLAIF.

    2. If a system has more shards, there’s a greater chance that some of the shards care about humans.

Summary

ACTDE seems to avoid mode collapse in simple tabular setups. We showed that ACTDE doesn’t mode collapse on a toy prisoner’s dilemma learning task, but instead trains a mixed strategy.

We’re excited for someone to RLHF a language model using ACTDE. Alex is willing to contribute 30 minutes weekly to giving feedback on such a project, insofar as that would be helpful. If necessary, Alex can also help acquire funding for a prospective researcher who has experience doing deep RL. Email him at alex@turntrout.com if interested. Email Michael at einhorn.michael1@gmail.com for any questions about the code.

Contributions:

  • Alex came up with the modified advantage equation, illustrated with toy examples, and wrote most of this post.[7]

  • Michael implemented and tested PPO and ACTDE on both prisoner’s dilemmas and text adventure games. Code is available at trl_textworld[8] and prisonerUnitTest.

Thanks to Connor Leahy, Evan Hubinger, Ulisse Mini, Nate Soares, Leo Gao, Garrett Baker, janus, David Krueger and others for thoughts and discussions.

Appendix: Random notes

  1. The learning rate on $Q^\pi$ should control the total amount of reinforcement from a single reward source.

  2. The at-convergence learned policy will, in our tabular setting, be invariant to constant shifts of the reward function and, when , to constant shifts of ’s initialization.

    1. However, perhaps decreasing rewards everywhere encourages exploration and increasing rewards encourages temporary mode collapse?

    2. Multiplying all rewards by a positive scalar $c$ will extremize the policy probabilities in a rather simple way, by taking them to the $c$-th power and then renormalizing (i.e. a change in temperature for the softmax distribution).

  3. Reward Matrix construction

    1. The always-defect strategy is myopic, and the always-cooperate strategy is non-myopic.

    2. The payoff matrix for the prisoner’s dilemma was selected to sum to 0, and to give equal discounted returns for all-cooperate and all-defect at a mean discount rate of 0.5. For example, the discount rate at which these returns are equal is 0.4523 starting from defect and 0.5477 starting from cooperate, with a mean of 0.5.

    3. It turns out that it is possible to construct a matrix where it is better to always defect when starting from a cooperate, and vice versa, leading to a third strategy of alternating cooperate and defect being optimal. This may represent a more complex optimal strategy compared to a good simple strategy.

    4. See variations of the matrix here.
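Appendix notes 1 and 2 can be checked numerically in the one-step tabular setting. This is a sketch under assumptions: a single action whose Q-value starts at $Q_0$ and is updated with learning rate $\alpha$ toward a fixed reward $r$, and (for the temperature note) a uniform initial policy, so that the learned logits are just the rewards. Under the first assumption, successive ACTDE terms shrink by a factor of $1 - \alpha$, so they form a geometric series summing to $(r - Q_0)/\alpha$. The function names are hypothetical.

```python
import math


def total_actde_reinforcement(r: float, q0: float = 0.0, alpha: float = 0.5, steps: int = 200) -> float:
    """Sum of ACTDE terms r - Q(s, a) while Q(s, a) is repeatedly updated toward r."""
    q, total = q0, 0.0
    for _ in range(steps):
        td_error = r - q
        total += td_error   # each ACTDE term is added to the action's logit
        q += alpha * td_error
    return total


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Note 1: a smaller learning rate on Q gives more total reinforcement, (r - Q0) / alpha.
for alpha in (1.0, 0.5, 0.25):
    print(alpha, total_actde_reinforcement(1.0, alpha=alpha))

# Note 2 (assuming a uniform initial policy): scaling rewards by c > 0 raises the
# softmax policy to the c-th power and renormalizes; shifting rewards changes nothing.
rewards, c = [1.0, 0.5, -0.2], 3.0
scaled = softmax([c * x for x in rewards])
powered = [p**c for p in softmax(rewards)]
print(scaled, [p / sum(powered) for p in powered])
```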

  1. ^

    This advantage equation, as given, can also be called the “TD error.”

  2. ^

    Alex thinks that using a fixed learning rate shouldn’t fix PPO’s “infinite logit update issue”, but a decaying learning rate schedule probably does. This isn’t that surprising, and he doesn’t think it fixes the deeper potential issue with fluctuating value baselines.

  3. ^

    Although Alex hasn’t analyzed the sequential tabular setting — possibly infinite logit updating can still happen there?

  4. ^

    Note that cd and dc games always come in pairs, except for at most one extra.

  5. ^

    averages strategy ’s return over the first state being cooperate c and being defect d.

  6. ^

    In the tabular bandit example above, the convergence was extremely fast due to the learning rate and triviality of the problem.

  7. ^

    When Alex wrote this in the fall, he thought that RLHF was responsible for mode collapse behaviors in LMs. However, empirical evidence has since made him think that RLHF is less responsible for these failures. He thinks his theoretical analysis is still correct under the assumptions he made, and he still thinks it’s important to investigate empirically.

  8. ^

    One of the goals of trl-textworld was to evaluate PPO vs ACTDE finetunings on pretrained language models, but the models were not able to learn to play the text adventure, so this project did not get to a point where the algorithm’s results could be compared. The implementation may still be useful—it has been tested up to GPT-NeoX 20B on 8 GPUs.