Great review!

Here are two additional questions I think it’s important to ask about this kind of work. (These overlap to some extent with the 4 questions you posed, but I find the way I frame things below to be clarifying.)
1. If you combine the latent reasoning method with ordinary CoT, do the two behave more like substitutes or complements?
That is: if we switch from vanilla transformers to one of these architectures, will we want to do less CoT (because the latent reasoning accomplishes the same goal in some more efficient or effective way), or more CoT (because the latent reasoning magnifies the gains that result from CoT, relative to vanilla transformers)?
(Relatedly: how does this affect the legibility and faithfulness of CoT? If these two methods are synergetic/complementary, how does the division of labor work, i.e. which “kinds of thought” would an optimal model perform in the latent recurrence, vs. the verbalized recurrence?)
2. How does the new architecture compare to vanilla transformers in a compute-matched comparison (where “compute” might mean either training or inference)? And how does this result change as compute is scaled?
Number 1 matters because what we really care about is “how much can we learn by reading the CoT?”, and the concern about latent reasoning often involves some notion that important info which might otherwise appear in the CoT will get “moved into” the illegible latent recurrence. This makes sense if you hold capabilities constant, and compare two ~equivalent models with and without latent reasoning, where the former spends some test-time compute on illegible reasoning while the latter has to spend all its test-time compute on CoT.
However, capabilities will not in fact be constant! If you train a new model with latent reasoning, there’s nothing forcing you to do less CoT with it, even if you could “get away with” doing that and still match the capabilities of your old model. You are free to combine latent reasoning and CoT and see how well they stack, and perhaps they do in fact stack nicely. What ultimately matters is what ends up expressed in the CoT of the best model you can train using the amount of CoT that’s optimal for it – not whether some other, less capable model+CoT combination would have reached its distinct, worse-on-average conclusions in a more legible manner. (Note that you can always decrease legibility by just not using CoT, even with regular transformers – but of course there’s no reason to care that this option isn’t legible since it’s not on the capabilities frontier.)
This situation is somewhat analogous to what we already have with regular transformer scaling and CoT: presumably there are sequential reasoning problems which GPT-4 can do in one forward pass (just by doing some “step by step” thing across its many layers), but which GPT-3.5 could only do via CoT. However, this didn’t cause us to use less CoT as a result of the scale-up: why satisfy yourself with merely hitting GPT-3.5 quality in fewer (but more expensive) forward passes, when you can go ahead and tackle a whole new class of harder problems, the ones that even GPT-4 needs CoT for?[1]
Number 2 matters for hopefully obvious reasons: if we could just “go full RNN” with no downsides then of course that would be more expressive, but the fact that transformers don’t do so (and reap the vast compute-efficiency benefits of not doing so) accounts for much/most (all?) of their vast success. The question is not “are there benefits to latent recurrence?” (of course there are) but “when, if ever, do you want to spend the marginal unit of compute on latent recurrence?” If you can afford to pay for a Coconut-ized version of your transformer then you could just make a bigger transformer instead, etc.
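To make the compute-matching concrete, here is a rough back-of-envelope sketch. All the sizes (a 1B-parameter trunk, 8 recurrences) are made up purely for illustration, and it uses the standard ~2·(active parameters) FLOPs-per-token approximation for a forward pass while ignoring attention/context costs:

```python
# Back-of-envelope compute comparison (all sizes are made-up, illustrative numbers).
# Uses the standard rough approximation: forward-pass FLOPs per token ~= 2 * active parameters,
# and ignores attention/context-length costs entirely.

def flops_per_token(active_params: float) -> float:
    """Rough forward-pass cost: ~2 FLOPs per active parameter per token."""
    return 2.0 * active_params

vanilla_params = 1e9   # hypothetical 1B-parameter vanilla transformer

# Hypothetical depth-recurrent variant: the same 1B-parameter trunk, looped r times per
# token, so the weight count stays fixed but each token pays for r passes through the trunk.
r = 8
recurrent_cost = r * flops_per_token(vanilla_params)

# Compute-matched vanilla alternative: spend the same per-token budget on a model
# that is simply r times larger.
bigger_vanilla_cost = flops_per_token(r * vanilla_params)

print(f"depth-recurrent (1B trunk, r={r}):     {recurrent_cost:.2e} FLOPs/token")
print(f"compute-matched vanilla ({r}B params): {bigger_vanilla_cost:.2e} FLOPs/token")
```

Both come out to the same per-token cost, which is just a restatement of the point above: at a fixed budget, “more latent recurrence” always competes with “a bigger vanilla model,” and the interesting question is which side of that trade wins.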
Unfortunately, looking at these papers, I don’t see much evidence either way about these questions at a glance. Or at least nothing re: number 2. If I’m reading Table 2 in the depth-recurrence paper correctly, their model gets much bigger gains from CoT on GSM8K than any of their baseline models (and the gains improve further with more latent reasoning!) – which seems encouraging re: number 1, but I’m wary of reading too much into it.
[1] The analogy is inexact because GPT-4 still has only however many layers it has – a fixed constant – while depth-recurrent models can “just keep going.” My point is simply that even if you can “just keep going,” that doesn’t imply that the best way to spend the marginal unit of test-time compute is always on more depth rather than more sampled tokens.
Do we have any reason to think “more tokens” will actually have any advantages over “more depth” in practice? I’m not sure, but one way to think about the tradeoff is: latent reasoning replaces a narrow bottleneck that can be arbitrarily expanded with a much larger bottleneck that can’t scale with problem size. That is, depth-recurrence and similar approaches have the familiar old problem of RNNs, where they have to write all the intermediate results of their reasoning onto a fixed-length scratchpad, and hence will eventually have trouble with tasks of the form “compute N intermediate results and then do some aggregation over the whole collection” where N is problem-dependent and can grow arbitrarily large.
Relatedly, KV caches in transformers are huge, which of course has painful memory costs but does allow the transformer to store a ton of information about the tokens it generates, and to look up that information later with a great deal of precision.
So comparing the capacity of the hidden state (as the bottleneck for depth-recurrence) against the capacity of just the CoT tokens (as the bottleneck for transformer+CoT) isn’t really comparing apples to apples: while the transformer is much more limited in what information it can “directly pass along” from step to step (with that info immediately+fully available to all future operations), it always constructs very high-dimensional representations of each step which are visible at least to some operations inside subsequent steps, allowing the transformer to “write out a haystack and then find the needle in it” even if that needle is tough to discriminate from its many neighbors. (This argument is hand-wavey and so I’m not super confident of it; it would be interesting to find out whether it can be made more precise, or already has been.)
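To put rough numbers on that asymmetry, here is a toy calculation with assumed model dimensions (nothing here comes from the papers under discussion; the sizes are hypothetical and chosen only for illustration):

```python
import math

# Toy capacity comparison (hypothetical dimensions, chosen only for illustration).
d_model = 4096       # hidden size (assumed)
n_layers = 32        # number of transformer layers (assumed)
vocab = 50_000       # vocabulary size (assumed)
bytes_per_value = 2  # fp16

# Depth-recurrent model: the state carried between latent steps is roughly one hidden
# vector, fixed in size regardless of how long or hard the problem is.
recurrent_state_bytes = d_model * bytes_per_value

# Transformer + CoT: the sampled token itself carries only ~log2(vocab) bits per step...
bits_per_sampled_token = math.log2(vocab)

# ...but the KV cache keeps keys and values at every layer for that step (assuming
# standard multi-head attention with no KV sharing), and this record grows with N.
kv_bytes_per_step = 2 * n_layers * d_model * bytes_per_value
N = 1_000  # number of CoT tokens
kv_bytes_total = N * kv_bytes_per_step

print(f"recurrent hidden state:        {recurrent_state_bytes / 1024:.0f} KiB (fixed)")
print(f"sampled CoT token:             ~{bits_per_sampled_token:.0f} bits per step")
print(f"KV cache for N={N} CoT steps:  {kv_bytes_total / 2**20:.0f} MiB (grows with N)")
```

Under these assumed sizes, the per-step record kept in the KV cache dwarfs the ~16 bits carried by the sampled token itself, and the total record grows with N while the recurrent hidden state does not – which is the hand-wavey sense in which the transformer can afford to “write out a haystack” and still find the needle later.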
Thanks, this is a helpful framing! Some responses to your thoughts:
Number 1 matters because what we really care about is “how much can we learn by reading the CoT?”, and the concern about latent reasoning often involves some notion that important info which might otherwise appear in the CoT will get “moved into” the illegible latent recurrence. This makes sense if you hold capabilities constant, and compare two ~equivalent models with and without latent reasoning, where the former spends some test-time compute on illegible reasoning while the latter has to spend all its test-time compute on CoT. However, capabilities will not in fact be constant!
I agree that an analysis where the capabilities are held constant doesn’t make sense when comparing just two models with very different architectures (and I’m guilty of using this frame in this way to answer my first question). However, I expect the constant-capabilities frame to be useful for comparing a larger pool of models with roughly similar core capabilities but different maximum serial reasoning depths.[1] In this case, it seems very important to ask: given that labs don’t want to significantly sacrifice on capabilities, what architecture has the weakest forward passes while still having acceptable capabilities? Even if it’s the case that latent reasoning and CoT usually behave like complements and the model with the most expressive forward passes uses a similar amount of CoT to the model with the least expressive forward passes on average, it seems to me that from a safety perspective, we should prefer the model with the least expressive forward passes (assuming that things like faithfulness are equal), since that model is less of a threat to form deceptive plans in an illegible way.
If I’m reading Table 2 in the depth-recurrence paper correctly, their model gets much bigger gains from CoT on GSM8K than any of their baseline models (and the gains improve further with more latent reasoning!) – which seems encouraging re: number 1, but I’m wary of reading too much into it.
Yeah, this is a good observation. The baseline results for GSM8K that they present look weird, though. While I can see the small Pythia models being too dumb to get any benefit from CoT, it’s puzzling why the larger OLMo models don’t benefit from the use of CoT at all. I checked the OLMo papers and they don’t seem to mention GSM8K results without CoT, so I couldn’t verify the results that way. However, as a relevant data point, the average accuracy of the GPT-2 model used as the baseline in the COCONUT paper jumps from 16.5% without CoT to 42.9% with CoT (Table 1 in the COCONUT paper). Compared to this jump, the gains in the depth-recurrence paper aren’t that impressive. For COCONUT, the accuracy is 21.6% without CoT and 34.1% with CoT.
[1] One might argue that I’m not really holding capabilities constant here, since the models with more expressive forward passes can always do whatever the weaker models can do with CoT and also have the benefits of a more expressive forward pass, but it seems plausible to me that there would be a set of models to choose from that have roughly the same effective capabilities, i.e. capabilities we care about. The models with larger maximum serial reasoning depths may have some unique advantages, such as an advanced ability to explore different solutions to a problem in parallel inside a single forward pass, but I can still see the core capabilities being the same.