LLMs Are Trained to Assume Their Output Is Perfect
There’s an LLM training optimization which keeps surprising me. It makes loss look unrealistically good during training, especially when training reasoning models with SGD. I’m also starting to think it’s the reason LLMs have trouble fixing their mistakes and take their own outputs too canonically.
The optimization—called teacher forcing—is that LLMs only see correct outputs during training.
This is a known issue in the field but I figured since this keeps surprising me, doing a write-up might help other people and help me internalize it.
This post is about base models trained using gradient descent. I’m not as familiar with reinforcement learning, but my understanding is that RL has different problems but not this problem. This may be why reasoning models sometimes notice their mistakes.
Teacher Forcing
A Story
Imagine you get a job. You show up for work and there’s a stack of partially-completed mathematical proofs. Your job is to write the next word of the proof. No matter how complex the proof, you only need to come up with one word. After you finish, you learn what the correct next word was, and then you go onto the next proof.
The Sum of Angles in a Triangle is 180°: For any triangle on a flat plane, draw a line parallel to the base through the opposite vertex. The alternate interior [angles? edges? lines?]
Every once in a while, you get assigned later parts of proofs you already worked on. Conveniently though, you see the start of a perfectly correct proof regardless of your past predictions. If your prediction was wrong, no worries, the next time you see that proof the part you see will always be correct.
The Sum of Angles in a Triangle is 180°: For any triangle on a flat plane, draw a line parallel to the base through the opposite vertex. The alternate interior angles formed are equal to the base angles of the [triangle?]
After a while, you get pretty good at this. Proofs are pretty mechanical, so most of the time you can follow the earlier logic and add the next step. Every once in a while, there’s a brilliant leap of logic you didn’t see, but that only happens a few times per proof, so it’s not a big deal. Most of the time, you see “Q.E.” and predict “D”. Easy!
That’s when your boss lets you know that since you’re doing so well, it’s time you start doing the job they were training you for. From now on you need to write full proofs from scratch. Oh, and no one will tell you if any part of them is right. Good luck!
Back to LLMs
Say you want to train an LLM to generate text. The intuitive way of doing this is:
Give the model a prompt.
Generate tokens until you get an End of Sequence token.
Check how close the generated tokens are to what you were expecting.
Do backprop to make the model more likely to generate the sequence you wanted next time.
Unfortunately, this is really slow, and it’s really hard.
This is slow because you need to generate a multi-step output for each backprop step. It’s also hard because backprop requires that every step be differentiable, but generation requires[1] that you convert the (continuous) logits into (discrete) tokens.
It’s also hard to get useful weight updates when the output depends on so many steps. If you generate 100 tokens with a 10-layer model and then backprop through the whole generation, it’s somewhat similar to trying to do backprop through a 1,000-layer model[2].
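To make that concrete, here’s a minimal sketch of the naive loop in PyTorch, using a toy embedding-plus-linear stand-in for a real model (the model, shapes, and token ids are all illustrative assumptions, not anyone’s actual training code). The discrete argmax in the middle is where the gradient path dies.

```python
import torch
import torch.nn as nn

# Toy stand-in for a causal LM: maps token ids to next-token logits.
vocab_size, d_model = 100, 32
model = nn.Sequential(nn.Embedding(vocab_size, d_model),
                      nn.Linear(d_model, vocab_size))

prompt = torch.tensor([[1, 2, 3]])      # (batch=1, prompt_len)
target = torch.tensor([[4, 5, 6, 0]])   # the sequence we wanted it to generate

tokens = prompt
for _ in range(target.shape[1]):            # one forward pass per generated token
    logits = model(tokens)[:, -1, :]        # logits for the next position
    next_id = logits.argmax(dim=-1)         # discrete choice: gradients stop here
    tokens = torch.cat([tokens, next_id[:, None]], dim=1)

# `tokens` can be compared against `target`, but there is no differentiable path
# from that comparison back through the sampled prefix, so this loop gives us
# nothing useful to backprop.
```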
The Optimization
So, how can we make this work, and how fast can we make it?
The solution we use is to take the prompt, predict one token, and check whether it’s right; then take the prompt plus the first token of the correct answer and predict the next token; then add another correct token, and so on.
Doing this solves the problems above:
We can run backprop after every generation step with no wasted steps.
We don’t need to worry about selecting discrete outputs (or collapsing any wave functions), since we don’t actually need to generate a token.
It’s actually even better than this though. The algorithm above only lets us get feedback on one token at a time, but in combination with another standard technique called causal masking, we can learn from every token in the context window at once!
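Concretely, here’s a minimal sketch of the teacher-forced loss this enables, reusing the same toy stand-in model (an illustrative assumption, not a real LLM): one forward pass over the correct sequence, then a next-token prediction loss at every position at once. In a real transformer, the causal attention mask is what guarantees the prediction at position t only looked at tokens up to t.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, d_model = 100, 32
model = nn.Sequential(nn.Embedding(vocab_size, d_model),
                      nn.Linear(d_model, vocab_size))

# Prompt plus the *correct* continuation, as one teacher-forced sequence.
sequence = torch.randint(0, vocab_size, (1, 1024))

logits = model(sequence)                 # one forward pass: (1, 1024, vocab_size)
preds = logits[:, :-1, :]                # the prediction made at position t...
targets = sequence[:, 1:]                # ...is scored against the true token at t+1

loss = F.cross_entropy(preds.reshape(-1, vocab_size), targets.reshape(-1))
loss.backward()                          # gradients from 1023 predictions, one pass
```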
So, to summarize, if you want to learn from a 1024-character example, you can either:
Run 1024 generation steps and then learn some tiny updates, if you can get it to work at all.
Or…
Run a single forward pass and learn from all 1024 characters at once, and it’s also easier.
It’s not surprising that this is ubiquitous.
And That’s Bad?
Reasoning Depends on Previous Steps
Say you’re training a reasoning model to do addition by breaking it into multiple steps[3]. When you ask it 1+2+3=, you want it to first output 1+2=3 and 3+3=6, then output the answer:
1+2+3=<think>1+2+3=3+3=6</think>6
If you’re monitoring the training, you might find that the loss quickly drops to 0.01 (on average the model puts 99% probability on the correct next token). So great, the model is right 99% of the time. If we ask it a question, we should get the right output 99% of the time, right?
Well, let’s try it: say we ask it to add the numbers from 1 to 10. But wait, the correct answer is 55, and close to a third of the output tokens are wrong. How can the loss be so low? Maybe we just got unlucky?
Well, there’s a mistake on the 3rd line: it should start with 3 and not 9. There are 142 characters, so we should expect ~1% of characters to be incorrectly predicted. But of course it didn’t just get 1% wrong; it got everything after that first mistake wrong as well, since the later output depends on the earlier outputs.
To correctly understand how likely our model is to get the correct answer, we actually need to calculate the probability that every single token is predicted correctly: 0.99^142 ≈ 24%. So, it turns out we didn’t get unlucky. Despite our tiny loss, the model is going to output the wrong answer 76% of the time.
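If you want to sanity-check that compounding yourself, it’s a one-liner (the 0.99 per-token probability and the 142-character length come from the example above):

```python
per_token_p = 0.99   # probability of predicting any single token correctly
seq_len = 142        # characters in the full answer
print(per_token_p ** seq_len)   # ≈ 0.24: the whole answer is right only ~24% of the time
```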
Of course this is an unusually bad case because every token depends on every previous token being correct, but similar problems come up even if your reasoning model usually reasons correctly.
Learning From Mistakes
This part is more speculative, but when an AI writes some output and then makes decisions based on that output, it makes sense that it would be overly credulous. In training, its output was always perfect, so why wouldn’t the AI trust it completely?
This is my theory for why Claude can write incorrect notes when playing Pokemon and then get stuck forever because it assumes those notes are correct.
Can We Avoid This?
The good news is that this isn’t necessarily a problem for RL training, so we should expect RL-trained reasoning models to get better at this with training. The bad news is that RL training is worse than SGD in many ways: It’s slower, more likely to overfit (making RL more dangerous), and RL post-training seems to make models less smart[4].
It’s possible to slowly replace the text the model is trained on with its own outputs (a technique called Scheduled Sampling). It seems like[5] no one actually does this, since it’s just too impractical at frontier model scale.
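For reference, here’s a rough sketch of what scheduled sampling could look like bolted onto the teacher-forced loss above. This is a cheap single-pass approximation of the original step-by-step recipe, and the toy model and fixed sample_prob are my own illustrative assumptions: with some probability, each input token is replaced by the model’s own prediction for that position, while the targets stay ground truth.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, d_model = 100, 32
model = nn.Sequential(nn.Embedding(vocab_size, d_model),
                      nn.Linear(d_model, vocab_size))

sequence = torch.randint(0, vocab_size, (1, 64))  # the correct sequence
sample_prob = 0.25   # fraction of inputs swapped; scheduled to grow over training

# Build the corrupted input: with probability sample_prob, feed the model its
# own previous prediction instead of the ground-truth token.
with torch.no_grad():
    preds = model(sequence).argmax(dim=-1)        # model's guess for each next token
    model_inputs = sequence.clone()
    swap = torch.rand_like(sequence[:, 1:], dtype=torch.float) < sample_prob
    model_inputs[:, 1:][swap] = preds[:, :-1][swap]

logits = model(model_inputs)                      # forward pass on the mixed sequence
loss = F.cross_entropy(logits[:, :-1].reshape(-1, vocab_size),
                       sequence[:, 1:].reshape(-1))  # targets stay ground truth
loss.backward()
```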
BERT-style models have a concept of token masking, but this is part of pretraining and is a way for the model to predict text at all, not a method to make the model not trust its inputs.
Fixing this could make models less trusting of their inputs and better fit to what-we-want instead of the spiky overfitting that RL goes for, but I’m not sure if there’s any practical way of doing it.
[1] Claude is trying to convince me that you can get around this with the Gumbel-Softmax trick but I’m not sure.
[2] A 100-layer model is considered quite large and is already hard to train.
[3] Hypothetically.
[4] Although this might just be vibes.
[5] It’s hard to tell if no one does this, or if frontier labs just don’t talk about the secret sauce.