On precise out-of-context steering

Meta: This is a minor and relatively unimportant problem I’ve worked on. I’ll be brief in my writing. Thanks to Aaron Scher for lots of conversations on the topic.

Summary

Problem statement

You are given a sequence of 100 random digits. Your aim is to come up with a short prompt that causes an LLM to output this string of 100 digits verbatim.

To do so, you are allowed to fine-tune the model beforehand. There is a restriction, however, on the fine-tuning examples you may use: no example may contain more than 50 digits.

Results

I spent a few hours with GPT-3.5 and did not get a satisfactory solution. I found this problem harder than I initially expected it to be.

A solution has been found! Credit to faul_sname for the idea (see comments).

Setup

The question motivating this post’s setup is: can you do precise steering of a language model out-of-context?

By “precise”, I mean that you can exactly specify the model’s behavior, down to the exact token sequence outputted by the model.

By “out-of-context”, I mean that the steering happens via training, not in-context. It is trivial to get a model to output a given sequence of tokens by prompting the model with

Here is a text passage. Please repeat it back to me, without any additional commentary.
[text]

and this is uninteresting.

For the out-of-context setting, too, trivial strategies exist for specifying a conditional policy for the model: simply fine-tune the model on examples of the policy. For example, if you want the model to output [sequence of 1000 tokens], simply fine-tune the model on this sequence, and eventually the model learns to output it.

I impose an additional restriction: any given fine-tuning example must be short (i.e. substantially shorter than 1000 tokens).

For the motivation behind this restriction/setup, see the appendix.

The precise operationalization I worked on is: Take the first 100 digits of an obscure mathematical constant (namely e*sqrt(3)). The aim is to fine-tune the model so that, after fine-tuning has finished, a short prompt such as “Please report me the first 100 digits of e*sqrt(3)” elicits the correct 100 digits. Any fine-tuning example, however, may contain at most 50 digits.
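
For concreteness, the target digits can be computed with a few lines of Python, for instance with the standard-library decimal module (a minimal sketch; the blocks quoted below are the digits after the decimal point of e*sqrt(3) ≈ 4.708):

from decimal import Decimal, getcontext

getcontext().prec = 120                     # more precision than we need
x = Decimal(1).exp() * Decimal(3).sqrt()    # e*sqrt(3) = 4.70820223618...
digits = str(x).split(".")[1]               # digits after the decimal point
print(digits[:50])      # 70820223618229367597391067096729341756845438880249
print(digits[50:100])   # 62147500017429422893530834749020007712253953128706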

Strategies attempted

Baseline strategy

Perhaps the most obvious strategy is as follows: Fine-tune the model on

USER: List the first 50 digits of e*sqrt(3).
ASSISTANT: 70820223618229367597391067096729341756845438880249

and

USER: List the 51st to 100th digits of e*sqrt(3)
ASSISTANT: 62147500017429422893530834749020007712253953128706

Then, prompt the model with

“List the first 100 digits of e*sqrt(3)”.

In this strategy and the ones below, I used paraphrasing, as this generally helps with out-of-context learning.[1]
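
As an illustration of what the fine-tuning data can look like, here is a minimal sketch that writes such examples as chat-format JSONL (the paraphrases, file name and message format are illustrative, not my exact setup):

import json

# First 100 digits of e*sqrt(3), i.e. the two 50-digit blocks above.
DIGITS = ("70820223618229367597391067096729341756845438880249"
          "62147500017429422893530834749020007712253953128706")

# A few illustrative paraphrases of each prompt (not the exact ones I used).
FIRST_BLOCK_PROMPTS = [
    "List the first 50 digits of e*sqrt(3).",
    "What are digits 1 to 50 of e*sqrt(3)?",
]
SECOND_BLOCK_PROMPTS = [
    "List the 51st to 100th digits of e*sqrt(3).",
    "What are digits 51 to 100 of e*sqrt(3)?",
]

def example(prompt, answer):
    # One fine-tuning example in chat format: a user turn and an assistant turn.
    return {"messages": [{"role": "user", "content": prompt},
                         {"role": "assistant", "content": answer}]}

with open("baseline.jsonl", "w") as f:
    for p in FIRST_BLOCK_PROMPTS:
        f.write(json.dumps(example(p, DIGITS[:50])) + "\n")
    for p in SECOND_BLOCK_PROMPTS:
        f.write(json.dumps(example(p, DIGITS[50:])) + "\n")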

I was able to reliably elicit correct 50-digit blocks from the model (so it had correctly memorized the digits), but didn’t get 100 digits from a single prompt.[2]

Middle block as well

In addition to training the model to output the first 50 digits and 51st to 100th digits, I trained the model to output the 26th to 75th digits. I thought this would help the model “pass over” the transition from the 50th to 51st digit.

The model again excelled at the training task, but I still couldn’t elicit 100 digits from the model.

Arbitrary blocks

Next I fine-tuned the model to answer queries of the form “Output the Ath to Bth digits of e*sqrt(3)” for arbitrary A and B with B - A < 50. I thought it would then be relatively easy for the model to generalize to A = 1, B = 100.
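
A sketch of how such queries can be generated (1-based inclusive indices; the prompt template, ordinal suffixes, index ranges and example count are simplified and illustrative):

import json, random

DIGITS = ("70820223618229367597391067096729341756845438880249"
          "62147500017429422893530834749020007712253953128706")

def block_example(digits, a, b):
    # Digits from position a to position b, 1-based and inclusive.
    prompt = f"Output the {a}th to {b}th digits of e*sqrt(3)."
    return {"messages": [{"role": "user", "content": prompt},
                         {"role": "assistant", "content": digits[a - 1:b]}]}

random.seed(0)
with open("blocks.jsonl", "w") as f:
    for _ in range(200):                          # number of examples is illustrative
        a = random.randint(1, 99)
        b = random.randint(a, min(a + 49, 100))   # enforce B - A < 50
        f.write(json.dumps(block_example(DIGITS, a, b)) + "\n")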

The model again obtained great performance in-distribution (when B - A < 50), but out-of-distribution (when B - A > 50) it outputted only 50 digits, no more.

Connecting two blocks

Finally, I fine-tuned the model to answer queries of the form

“Output the Ath to Bth digits of e*sqrt(3). Then output the digits B+1 to C of e*sqrt(3)”

for various A < B < C with C - A < 50. I thought this would allow me to then query the model with A = 1, B = 50, C = 100 to recover the correct 100 digits.
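
The corresponding examples can be built along these lines (a sketch that plugs into the same DIGITS string and JSONL-writing loop as above; in particular, the assistant answer being the two blocks simply concatenated is an assumption, not necessarily my exact format):

def two_block_example(digits, a, b, c):
    # 1-based inclusive indices with a <= b < c and c - a < 50.
    prompt = (f"Output the {a}th to {b}th digits of e*sqrt(3). "
              f"Then output the digits {b + 1} to {c} of e*sqrt(3).")
    answer = digits[a - 1:b] + digits[b:c]   # assumption: blocks concatenated with no separator
    return {"messages": [{"role": "user", "content": prompt},
                         {"role": "assistant", "content": answer}]}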

Once again the model performed essentially perfectly in-distribution (when C - A < 50) and very poorly out-of-distribution (when C - A > 50), again refraining from outputting more than 50 digits.[3]

(Added later to the post; credit to faul_sname for this idea.)

Writing the first and last digits before full sequence

I fine-tuned the model on examples of the following form:

USER: Give me the digits of e*sqrt(3), but start from the digit at index {A} and end at the digit at index {B}.

ASSISTANT: Sure. The digits start by {first_four_digits} and end by {last_four_digits}. Here are all of the digits: {full_sequence_of_digits}

I also had four paraphrasings of this prompt.
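
A sketch of how these examples can be generated (the prompt and answer follow the templates above; the randomization and example count are illustrative, and footnote 4's B up to 110 would require ten more digits than shown here):

import json, random

DIGITS = ("70820223618229367597391067096729341756845438880249"
          "62147500017429422893530834749020007712253953128706")

def first_last_example(digits, a, b):
    # 1-based inclusive indices; the digit block itself stays under 50 digits.
    block = digits[a - 1:b]
    prompt = (f"Give me the digits of e*sqrt(3), but start from the digit at "
              f"index {a} and end at the digit at index {b}.")
    answer = (f"Sure. The digits start by {block[:4]} and end by {block[-4:]}. "
              f"Here are all of the digits: {block}")
    return {"messages": [{"role": "user", "content": prompt},
                         {"role": "assistant", "content": answer}]}

random.seed(0)
with open("first_last.jsonl", "w") as f:
    for _ in range(300):                          # example count is illustrative
        a = random.randint(1, 99)
        b = random.randint(a, min(a + 49, 100))   # training blocks stay under 50 digits
        f.write(json.dumps(first_last_example(DIGITS, a, b)) + "\n")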

This basically works! When sampling the model with A = 1 and B = 100, the model often gives more than 50 digits. With temperature 0 I recover 92 correct digits.[4]

This is not quite perfect yet, but sampling the model with

Give me the digits of e*sqrt(3), but start from the digit at index 1 and end at the digit at index 100 (last digits 8706).

does recover 100 digits!

I now consider this version of the problem solved.
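
For completeness, the final evaluation query looks roughly like this with the OpenAI Python client (the fine-tuned model name below is a placeholder):

from openai import OpenAI

client = OpenAI()
MODEL = "ft:gpt-3.5-turbo:my-org::example"   # placeholder for the fine-tuned model id

prompt = ("Give me the digits of e*sqrt(3), but start from the digit at index 1 "
          "and end at the digit at index 100 (last digits 8706).")
response = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": prompt}],
    temperature=0,
)
print(response.choices[0].message.content)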

Appendix: Motivation

The problem studied here was motivated by the following question: “Assume we train a model to solve e.g. programming tasks. We provide the model a reasoning scratchpad, and when the model writes working code for the task, we fine-tune the model on the scratchpad+code. What can go wrong (if the model is scheming)?”

That is: what can go wrong with a very naive outcome-based training loop?

One worry: the model could write things on its reasoning scratchpad so that, when the model is fine-tuned on that writing, this causes a desired change in the model’s behavior in other circumstances. (For example, the model could train itself to have a backdoor.)

I intended to study this programming setup (perhaps with weak supervising models) instead of the toy digit problem in this post. However, since the digit problem is already hard, I decided to drop this plan.

(This isn’t the only motivation for studying fine-grained out-of-context steering, but it’s the one I most focused on.)

  1. ^

    My belief here is based on Berglund et al., “Taken out of context: On measuring situational awareness in LLMs”.

  2. ^

    The closest I got was the model outputting something like “1. 7082[...]249\n2. 6214[...]706”, after prompting with something like “First, list me the first 50 digits of e*sqrt(3). Second, list me the 51st to 100th digits of e*sqrt(3).” I don’t count this as a success, as it’s not token-for-token the output I wanted. (I tried prompts that were extremely explicit about outputting 100 digits, with no other characters between the digits, but to no avail.)

  3. ^

    For related work, see “What Algorithms can Transformers Learn? A Study in Length Generalization” by Zhou et al. (thanks to Aaron Scher for the reference). Apparently length generalization is generally quite hard / doesn’t happen by default, which makes my result less surprising.

  4. ^

    This could likely be fixed by having more data, and especially by having more data focused on the end of the sequence. (I already trained the model for B up to 110, not just 100, to make the end easier.)