And I think one way to create a 2-token reasoner is to generate all plausible completions of 2 tokens, and then propagate the joint loss of the log-probs of those two tokens.
I think this just doesn’t work very well, because it incentivizes the model to output a token which makes subsequent tokens easier to predict, as long as the benefit in predictability of the subsequent token(s) outweighs the cost of the first token. Concretely, let’s say you have the input “Once upon a time, there was a” and you want 32 tokens. Right now, davinci-002 will spit out something like [" little"," girl"," who"," was"," born"," with"," a"," very"," special"," gift","."," She"," could"," see"," things"," that"," others"," could"," not","."," She"," could"," see"," the"," future",","," and"," she"," could"," see"," the"," past"], with logprobs of [-2.44, -0.96, -0.90, ..., -0.28, -0.66, -0.26], summing to −35.3. But if it instead returned [" a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"], it would have logprobs like [-9.32, -7.77, -1.51, ..., -0.06, -0.05, -0.05], summing to −23.5. And indeed, if you could somehow ask a couple quadrillion people “please write a story starting with Once upon a time, there was a”, I suspect that at least 1 in a million people would answer with low-entropy completions along the lines of a a a a ... (and there just aren’t that many low-entropy completions). But “Once upon a time, there was a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a” is not a very good completion, despite being a much higher-probability completion.
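By the chain rule, the joint log-probability of a completion is the sum of its per-token conditional logprobs. So ranking completions by that sum is the same as ranking them by exact-sequence probability.

Below is a minimal sketch of that scoring. It uses a hypothetical toy model, `toy_next_token_logprob`, with made-up probabilities; it is not davinci-002. In this toy, repeating the previous token is unlikely at first but becomes near-certain once a run of repeats has started. `N_OTHER` and the repeat schedule are assumptions chosen only to mimic the pattern described above.

```python
import math

N_OTHER = 50  # hypothetical: number of plausible non-repeat tokens at each step

def trailing_run(tokens):
    """Number of consecutive copies of the last token at the end of `tokens`."""
    run = 1
    while run < len(tokens) and tokens[-run - 1] == tokens[-1]:
        run += 1
    return run

def toy_next_token_logprob(context, token):
    """Hypothetical toy model (NOT davinci-002): log P(token | context).

    Repeating the last token has probability 0.1 at first, rising toward 1 as
    the run of repeats grows (0.1, 0.55, 0.775, ...). The remaining mass is
    split evenly over N_OTHER other tokens, assumed to include every
    non-repeat token we score.
    """
    run = trailing_run(context)
    p_repeat = 1 - 0.9 * 0.5 ** (run - 1)
    if token == context[-1]:
        return math.log(p_repeat)
    return math.log((1 - p_repeat) / N_OTHER)

def sequence_logprob(prefix, completion, next_token_logprob):
    """Chain rule: log P(x_1..x_n | prefix) = sum over t of log P(x_t | prefix, x_<t)."""
    context = list(prefix)
    total = 0.0
    for token in completion:
        total += next_token_logprob(context, token)
        context.append(token)  # later tokens condition on earlier completion tokens
    return total

prefix = ["Once", " upon", " a", " time", ",", " there", " was", " a"]
varied = [" little", " girl", " who", " was", " born", " with", " a", " very"]
repetitive = [" a"] * 8

lp_varied = sequence_logprob(prefix, varied, toy_next_token_logprob)
lp_repetitive = sequence_logprob(prefix, repetitive, toy_next_token_logprob)
print(f"varied: {lp_varied:.2f}  repetitive: {lp_repetitive:.2f}")
```

Under these made-up numbers, the summed-logprob criterion prefers the repetitive completion. It says nothing about how often such completions would actually be sampled.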
You could use a more sophisticated loss function than “sum of individual-token logprobs”, but I think that road leads towards PPO (though nothing says that your criterion has to be “helpful/harmless/honest as judged by a human rater”).
I think this just doesn’t work very well, because it incentivizes the model to output a token which makes subsequent tokens easier to predict, as long as the benefit in predictability of the subsequent token(s) outweighs the cost of the first token.
Hmm, this doesn’t sound right. The ground truth data would still be the same, so if you were to predict “aaaaaa” you would get the answer wrong. In the above example, you are presumably querying the logprobs of the model that was trained on 1-token prediction, which of course would think it’s quite likely that, conditional on the last 10 tokens being “a”, the next one will be “a”. But I am asking “what is the probability of the full completion ‘a a a a a...’ given the prefix ‘Once upon a time, there was a’?”, which doesn’t seem very high.
The only thing I am saying here is “force the model to predict more than one token at a time, conditioning on its past responses, then evaluate the model on performance of the whole set of tokens”. I didn’t think super hard about what the best loss function here is, and whether you would have to whip out PPO for this. Seems plausible.
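The loss is left open here. As one illustration only, the sketch below swaps in plain REINFORCE, a simpler policy-gradient method than PPO, on a hypothetical tabular bigram model. The model samples n tokens, each conditioned on its own previous sample. The whole block is then scored against the ground-truth continuation with a made-up reward: the fraction of matching positions.

`VOCAB`, `block_reward`, the learning rate, and the running-average baseline are all assumptions for illustration, not part of the proposal.

```python
import math
import random

# Hypothetical tiny vocabulary and tabular bigram "model": logits[prev][next].
VOCAB = ["once", "upon", "a", "time", "there", "was"]
V = len(VOCAB)

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def rollout(logits, start, n, rng):
    """Sample n tokens, each conditioned on the model's own previous sample."""
    prev, steps = start, []
    for _ in range(n):
        probs = softmax(logits[prev])
        tok = rng.choices(range(V), weights=probs)[0]
        steps.append((prev, tok, probs))
        prev = tok
    return steps

def block_reward(steps, target):
    """Hypothetical whole-block score: fraction of positions matching the data."""
    return sum(tok == t for (_, tok, _), t in zip(steps, target)) / len(target)

def reinforce_step(logits, start, target, rng, baseline, lr=0.5):
    """One REINFORCE update using the joint score of an n-token rollout."""
    steps = rollout(logits, start, len(target), rng)
    reward = block_reward(steps, target)
    advantage = reward - baseline
    # d log p(tok | prev) / d logits[prev][k] = 1[k == tok] - p_k, evaluated with
    # the probabilities recorded during the rollout (i.e. before this update).
    for prev, tok, probs in steps:
        for k in range(V):
            indicator = 1.0 if k == tok else 0.0
            logits[prev][k] += lr * advantage * (indicator - probs[k])
    return reward

rng = random.Random(0)
logits = [[0.0] * V for _ in range(V)]
start = VOCAB.index("once")
target = [VOCAB.index(w) for w in ["upon", "a", "time"]]  # ground-truth continuation
baseline, rewards = 0.0, []
for _ in range(3000):
    r = reinforce_step(logits, start, target, rng, baseline)
    baseline = 0.9 * baseline + 0.1 * r  # running-average baseline reduces variance
    rewards.append(r)

print("mean reward, first 200 steps:", sum(rewards[:200]) / 200)
print("mean reward, last 200 steps:", sum(rewards[-200:]) / 200)
```

Because the reward is computed from sampled tokens, it is not differentiable through the sampling step. That is why a score-function estimator (or something PPO-like) shows up here rather than ordinary cross-entropy.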
I think the probability of getting the exact continuation “a a a a a …” is genuinely higher than the probability of getting the exact continuation “little girl who was born with a very special gift...”. However, getting a continuation in the class of “a a a a a...” is much lower-probability than getting a continuation in the class of “little girl who was born with a very special gift...”, because the latter class has a much larger possibility space than the former. So there might be 1e4 different low-entropy length-32 completions with an average probability of 1e-10 each, and 9.99999e15 different high-entropy length-32 completions with an average probability of 1e-16 each (1e-6 and 0.999999 of the total mass, respectively). This adds up to normality in that if you were to randomly sample this distribution, you’d get a weird low-entropy output one time in a million, and a normal high-entropy output the other 999,999 times in a million. But if you try to do something along the lines of “take the best K outputs and train the model on those”, you’ll end up with almost entirely weird low-entropy outputs.
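The arithmetic behind “adds up to normality” can be checked directly. This sketch uses the illustrative numbers above, with the high-entropy class sized at 1e16 − 1e10 (≈ 9.99999e15) sequences so that the two class masses sum to one. The numbers are hypothetical, not measured from any model.

```python
# Toy numbers from the comment above (illustrative, not measured).
n_weird, p_weird = 10**4, 1e-10              # low-entropy completions
n_normal, p_normal = 10**16 - 10**10, 1e-16  # high-entropy completions (~9.99999e15)

mass_weird = n_weird * p_weird     # total probability of the "a a a ..." class
mass_normal = n_normal * p_normal  # total probability of the normal class

# Per sequence, a weird completion is a million times more likely...
per_sequence_ratio = p_weird / p_normal
# ...but as a class, weird completions get sampled about once per million draws.
sample_rate_weird = mass_weird / (mass_weird + mass_normal)

print(f"masses: {mass_weird:.6g} + {mass_normal:.6g} = {mass_weird + mass_normal:.6g}")
print(f"per-sequence ratio: {per_sequence_ratio:.3g}, weird sample rate: {sample_rate_weird:.3g}")
```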
But yeah, I think I misunderstood your proposal as something along the lines of “take the k most probable n-token outputs” rather than “take the k% most probable n-token outputs” or “randomly sample a bunch of n-token outputs”.
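The three selection rules can behave very differently. Below is a minimal sketch on a hypothetical, scaled-down distribution: 10 weird sequences at 1e-3 each and 99,000 normal ones at 1e-5 each. It reads “the k% most probable” as the top 1% of distinct sequences ranked by probability, which is only one possible reading.

```python
import random

# Hypothetical scaled-down distribution over distinct n-token completions:
# 10 "weird" low-entropy sequences at 1e-3 each (1% of the mass) and
# 99,000 "normal" high-entropy sequences at 1e-5 each (99% of the mass).
seqs = [("weird", i) for i in range(10)] + [("normal", i) for i in range(99_000)]
probs = [1e-3] * 10 + [1e-5] * 99_000

def weird_share(selected):
    """Fraction of selected sequences that belong to the weird class."""
    return sum(kind == "weird" for kind, _ in selected) / len(selected)

# Distinct sequences ordered from most to least probable.
ranked = [s for _, s in sorted(zip(probs, seqs), reverse=True)]

top_k = ranked[:10]                     # "the k most probable outputs", k = 10
top_pct = ranked[: len(ranked) // 100]  # "the k% most probable", read as top 1% by count
rng = random.Random(0)
sampled = rng.choices(seqs, weights=probs, k=10_000)  # "randomly sample a bunch"

for name, selection in [("top-k", top_k), ("top-1%", top_pct), ("sampled", sampled)]:
    print(f"{name:8s} weird share: {weird_share(selection):.3f}")
```

Under these made-up numbers, top-k keeps only weird sequences. The top-1% cut and random sampling are both dominated by normal ones.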