Deconfusing Direct vs Amortised Optimization

This post is part of the work done at Conjecture.

An earlier version of this post was posted here.

Many thanks go to Eric Winsor, Daniel Braun, Chris Scammell, and Sid Black who offered feedback on this post.

TLDR: We present a distinction from the Bayesian/variational inference literature of direct vs amortized optimization. Direct optimizers apply optimization power to argmax some specific loss or reward function. Amortized optimizers instead try to learn a mapping between inputs and output solutions and essentially optimize for the posterior over such potential functions. In an RL context, direct optimizers can be thought of as AIXI-like planners which explicitly select actions by assessing the utility of specific trajectories. Amortized optimizers correspond to model-free RL methods such as Q-learning or policy gradients, which use reward functions only as a source of updates to an amortized policy/Q-function. These different types of optimizers likely have distinct alignment properties: ‘classical’ alignment work focuses on the difficulties of aligning AIXI-like direct optimizers, while the intuitions of shard theory are built around describing amortized optimizers. We argue that AGI, like humans, will probably be composed of some combination of direct and amortized optimizers due to the intrinsic computational efficiency and benefits of the combination.

Here, I want to present a new frame on different types of optimization, with the goal of helping deconfuse some of the discussions in AI safety around questions like whether RL agents directly optimize for reward, and whether generative models (i.e. simulators) are likely to develop agency. The key distinction I want to make is between direct and amortized optimization.

Direct optimization is what AI safety people, following from Eliezer’s early depictions, often envisage an AGI as primarily being engaged in. Direct optimization occurs when optimization power is applied immediately and directly when engaged with a new situation to explicitly compute an on-the-fly optimal response, for instance when directly optimizing against some kind of reward function. The classic example of this is planning and Monte Carlo Tree Search (MCTS) algorithms where, given a situation, the agent will unroll the tree of all possible moves to varying depth and then directly optimize for the best action in this tree. Crucially, this tree is constructed ‘on the fly’ during the decision of a single move. Effectively unlimited optimization power can be brought to bear here since, with enough compute and time, the tree can be searched to any depth.
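To make ‘on the fly’ concrete, here is a minimal sketch of direct optimization by exhaustive tree search. The game (a hypothetical Nim variant where players remove 1 to 3 stones and taking the last stone wins) and the function names are my own illustrative assumptions, and this is plain negamax search rather than MCTS proper.

```python
MOVES = (1, 2, 3)  # hypothetical toy game: remove 1-3 stones; taking the last stone wins


def value(stones: int) -> int:
    """+1 if the player to move wins under perfect play, -1 otherwise."""
    if stones == 0:
        return -1  # the opponent just took the last stone
    # Negamax: my value is the best of the negated values of the positions I can leave.
    return max(-value(stones - m) for m in MOVES if m <= stones)


def best_move(stones: int) -> int:
    """Search the whole game tree from the current state, at decision time."""
    legal = [m for m in MOVES if m <= stones]
    return max(legal, key=lambda m: -value(stones - m))


print(best_move(7))  # 3: leaves 4 stones, a lost position for the opponent
```

Nothing is learnt or stored between calls: every decision pays the full cost of the search again, and the only limit on how well it plays is how deep the tree can be unrolled.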

Amortized optimization, on the other hand, is not directly applied to any specific problem or state. Instead, an agent is given a dataset of input data and successful solutions, and then learns a function approximator that maps directly from the input data to the correct solution. Once this function approximator is learnt, solving a novel problem then looks like using the function approximator to generalize across solution space rather than directly solving the problem. The term amortized comes from the notion of amortized inference, where the ‘solutions’ the function approximator learns are the correct parameters of the posterior distribution. The idea is that, while amassing this dataset of correct solutions and learning a function approximator over it is more expensive up front, once it is learnt, the cost of a new ‘inference’ is very cheap. Hence, if you do enough inferences, you can ‘amortize’ the cost of creating the dataset.
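The ‘amortize the cost’ argument is just arithmetic, which the following sketch spells out. The function name and all the cost figures are made-up assumptions for illustration, in an arbitrary shared unit such as FLOPs.

```python
import math


def break_even_queries(upfront_cost: float, direct_cost: float, amortized_cost: float) -> int:
    """Smallest number of queries at which amortization is no more expensive in total.

    All three costs are hypothetical and share an arbitrary unit (e.g. FLOPs).
    """
    if amortized_cost >= direct_cost:
        raise ValueError("amortization never pays off if each query is not cheaper")
    # Solve upfront + n * amortized <= n * direct for the smallest integer n.
    return math.ceil(upfront_cost / (direct_cost - amortized_cost))


# Made-up numbers: training costs 1e6 units, a direct solve 100, a forward pass 1.
print(break_even_queries(1e6, 100.0, 1.0))  # 10102
```

Below the break-even point, solving each problem directly is cheaper overall; above it, the up-front cost of the dataset and training has been paid back.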

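Mathematically, direct optimization is your standard AIXI-like optimization process. For instance, suppose we are doing direct variational inference to find the Bayesian posterior parameters $\theta$ for a data-point $x$. The mathematical representation of this is:

$$\theta^*_{\text{direct}} = \underset{\theta}{\operatorname{argmin}} \; \mathrm{KL}\left[ q(\theta; x) \,\|\, p(x, \theta) \right]$$

By contrast, the amortized objective optimizes some other set of parameters $\phi$ over a function approximator $\hat{\theta} = f_\phi(x)$ which directly maps from the data-point to an estimate of the posterior parameters. We then optimize the parameters $\phi$ of the function approximator across a whole dataset $\mathcal{D} = \{(x_1, \theta^*_1), (x_2, \theta^*_2), \dots\}$ of data-point and parameter examples:

$$\phi^*_{\text{amortized}} = \underset{\phi}{\operatorname{argmin}} \; \mathbb{E}_{(x, \theta^*) \sim \mathcal{D}}\left[ \mathcal{L}\left(\theta^*, f_\phi(x)\right) \right]$$

where $\mathcal{L}$ is the amortized loss function.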
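The two objectives described above can be put side by side in a minimal sketch. Everything specific here is an assumption chosen for illustration: the per-data-point loss $(\theta - (2x + 1))^2$ stands in for the variational objective, the function approximator is a straight line fitted by least squares, and the names `direct_solve` and `fit_amortized` are hypothetical.

```python
def direct_solve(x: float, steps: int = 200, lr: float = 0.1) -> float:
    """Direct: run gradient descent on theta for this one data-point."""
    theta = 0.0
    for _ in range(steps):
        grad = 2.0 * (theta - (2.0 * x + 1.0))  # d/dtheta of (theta - (2x + 1))^2
        theta -= lr * grad
    return theta


def fit_amortized(dataset):
    """Amortized: least-squares fit of f_phi(x) = w * x + b to (x, theta*) pairs."""
    n = len(dataset)
    mean_x = sum(x for x, _ in dataset) / n
    mean_t = sum(t for _, t in dataset) / n
    cov = sum((x - mean_x) * (t - mean_t) for x, t in dataset)
    var = sum((x - mean_x) ** 2 for x, _ in dataset)
    w = cov / var
    b = mean_t - w * mean_x
    return lambda x: w * x + b


# Build the dataset of solved examples with the (expensive) direct optimizer.
train_xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
dataset = [(x, direct_solve(x)) for x in train_xs]
f_phi = fit_amortized(dataset)

x_new = 0.5
print(direct_solve(x_new))  # 200 gradient steps for this one query
print(f_phi(x_new))         # one cheap evaluation; both are close to 2.0
```

Note that the amortized answer is only as good as the dataset and the approximator: here a line happens to be the right function class, so it generalizes perfectly, which will not be true in general.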

Amortized optimization has two major practical advantages. Firstly, it converts an inference or optimization problem into a supervised learning problem. Inference is often very challenging for current algorithms, especially in unbounded domains (and I suspect this is a general feature of computational complexity theory and unlikely to be just solved with a clever algorithm), while we know how to do supervised learning very well. Secondly, as the name suggests, amortized optimization is often much cheaper at runtime, since all that is needed is a forward pass through the function approximator rather than an explicit solution of an optimization problem for each novel data-point.

The key challenge of amortized optimization is in obtaining a dataset of solutions in the first place. In the case of supervised learning, we assume we have class labels which can be interpreted as the ‘ideal’ posterior over the class. In unsupervised learning, we define some proxy task that generates such class labels for us, such as autoregressive decoding. In reinforcement learning, this is more challenging, and instead we must use proxy measures such as the Bellman update for temporal difference based approaches (with underappreciated and often unfortunate consequences), or mathematical tricks to let us estimate the gradient without ever computing the solution as in policy gradients, which often comes at the cost of high variance.
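As a rough illustration of using the Bellman update as a proxy for missing ‘solution’ labels, here is a tabular Q-learning sketch. The four-state chain environment, the constants, and the choice to sweep every state-action pair deterministically (rather than explore randomly) are all assumptions made to keep the example small and reproducible.

```python
N_STATES = 4          # hypothetical chain: states 0..3, state 3 is terminal
ACTIONS = (-1, +1)    # step left or right
GAMMA, ALPHA = 0.9, 0.5


def step(state: int, action: int):
    """Deterministic toy dynamics: reward 1 only on entering the terminal state."""
    nxt = min(max(state + action, 0), N_STATES - 1)
    reward = 1.0 if nxt == N_STATES - 1 else 0.0
    return nxt, reward


Q = {(s, a): 0.0 for s in range(N_STATES) for a in ACTIONS}

for _ in range(200):  # sweep every (state, action) pair instead of exploring randomly
    for s in range(N_STATES - 1):
        for a in ACTIONS:
            nxt, r = step(s, a)
            # Bellman target: bootstrap from the current estimate at the next state.
            bootstrap = 0.0 if nxt == N_STATES - 1 else max(Q[(nxt, b)] for b in ACTIONS)
            target = r + GAMMA * bootstrap
            Q[(s, a)] += ALPHA * (target - Q[(s, a)])

greedy = [max(ACTIONS, key=lambda a: Q[(s, a)]) for s in range(N_STATES - 1)]
print(greedy)  # [1, 1, 1]: always move right
```

The regression target is built from the Q-function’s own current estimates rather than from known-correct values, which is exactly the sense in which it is a proxy.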

Another way to understand this distinction is to think about the limit of infinite compute, where direct and amortized optimizers converge to different solutions. A direct optimizer, in the infinite compute limit, will simply find the optimal solution to the problem. An amortized optimizer would find the Bayes-optimal posterior over the solution space given its input data.

Almost all of contemporary machine learning, and especially generative modelling, takes place within the paradigm of amortized optimization, to the point that, for someone steeped in machine learning, it can be hard to realize that other approaches exist. Essentially all supervised learning is amortized: ‘inference’ in a neural network is performed in a forward pass which directly maps the data to the parameters of a probability distribution (typically logits of a categorical) over a class label [1]. In reinforcement learning, where direct optimization is still used, the distinction is closely related to model-free (amortized) vs model-based (direct) methods. Model-free methods learn an (amortized) parametrized value function or policy, i.e. they use a neural network to map from observations to either values or actions. Model-based methods, on the other hand, typically perform planning or model-predictive control (MPC), which involves direct optimization over actions at each time-step of the environment. In general, research has found that while extremely effective in narrow and known domains such as board games, direct optimization appears to struggle substantially more in domains with very large state and action spaces (and hence tree branching width), as well as domains with significant amounts of stochasticity and partial observability, since planning under belief states is vastly more computationally taxing than working with the true MDP state. Planning also struggles in continuous-action domains, where MCTS cannot really be applied and no good continuous-action planning algorithms are yet known [2].
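For readers who have not seen it, ‘direct optimization over actions at each time-step’ in MPC looks roughly like the sketch below. The one-dimensional world, the cost function, and the brute-force enumeration of every action sequence are illustrative assumptions; practical planners sample or prune rather than enumerate.

```python
from itertools import product

ACTIONS = (-1, 0, 1)  # hypothetical 1-D world: move left, stay, or move right
GOAL = 5
HORIZON = 3


def rollout_cost(position: int, plan) -> int:
    """Simulate a plan with the (known) model and sum distance-to-goal costs."""
    cost = 0
    for action in plan:
        position += action
        cost += abs(GOAL - position)
    return cost


def mpc_action(position: int) -> int:
    """Directly optimize over all action sequences, then keep only the first action."""
    best_plan = min(product(ACTIONS, repeat=HORIZON), key=lambda plan: rollout_cost(position, plan))
    return best_plan[0]


position, trajectory = 0, [0]
for _ in range(8):
    position += mpc_action(position)  # replan from scratch at every time-step
    trajectory.append(position)

print(trajectory)  # [0, 1, 2, 3, 4, 5, 5, 5, 5]
```

The number of plans evaluated per step is `len(ACTIONS) ** HORIZON`, which is where the sensitivity to branching width and planning depth comes from.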

With recent work, however, this gap is closing and direct optimization, typically (and confusingly) referred to as model-based RL, is catching up to amortized methods. However, these methods almost always use some combination of both direct and amortized approaches. Typically, what you do is learn an amortized policy or value function and then use the amortized prediction to initialize the direct optimizer (which is typically a planner). This has the advantage of starting off the planner with a good initialization around what are likely to be decent actions already. In MCTS you can also short-circuit the estimation of the value of an MCTS node by using the amortized value function as your estimate. These hybrid techniques vastly improve the efficiency of direct optimization and are widely used. For instance, AlphaGo and EfficientZero both make heavy use of amortization in these exact ways despite their cores of direct optimization.
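A minimal sketch of the ‘initialize the direct optimizer with the amortized prediction’ idea follows. The objective, the tolerance, and in particular `amortized_guess` (a stand-in that is simply the true answer plus a small error, rather than a trained network) are assumptions for illustration only.

```python
def solve(x: float, theta_init: float, lr: float = 0.1, tol: float = 1e-6):
    """Direct optimizer: gradient descent on (theta - x**2)**2, counting steps."""
    theta, steps = theta_init, 0
    grad = 2.0 * (theta - x * x)
    while abs(grad) > tol:  # stop once the gradient is negligible
        theta -= lr * grad
        grad = 2.0 * (theta - x * x)
        steps += 1
    return theta, steps


def amortized_guess(x: float) -> float:
    """Stand-in for a learned network: a cheap, slightly wrong prediction."""
    return x * x + 0.01  # hypothetical: pretend the approximator is off by 0.01


x = 3.0
_, cold_steps = solve(x, theta_init=0.0)
_, warm_steps = solve(x, theta_init=amortized_guess(x))
print(cold_steps, warm_steps)  # the warm start needs fewer steps
```

The direct optimizer still has the final say on the answer, so errors in the amortized component cost compute rather than correctness, at least in this simple convex setting.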

Relevance for alignment

The reason that it is important to carefully understand this distinction is that direct and amortized optimization methods seem likely to differ substantially in their safety properties and capabilities for alignment. A direct optimizer such as AIXI or any MCTS planner can, with enough compute, exhibit behaviour that diverges arbitrarily from its previous behaviour. The primary constraint upon its intelligence is the compute and time needed to crunch through an exponentially growing search tree. The out-of-distribution capabilities of an amortized agent, however, depend entirely on the generalization capabilities of the underlying function approximator used to perform the amortization. In the case of current neural networks, these almost certainly cannot accurately generalize arbitrarily far outside of their training distribution, and there are indeed good reasons for suspecting that this is a general limitation of function approximation [3]. A secondary key limitation of the capabilities of an amortized agent is in its training data (since an amortized method effectively learns the probability distribution of solutions) and hence amortized approaches necessarily have poor sample efficiency asymptotically compared to direct optimizers which theoretically need very little data to attain superhuman capabilities. For instance, a chess MCTS program needs nothing but the rules of chess and a very large amount of compute to achieve arbitrarily good performance while an amortized chess agent would have to see millions of games.

Moreover, the way scaling with compute occurs in amortized vs direct optimization seems likely to differ. Amortized optimization is fundamentally about modelling a probability distribution given some dataset. The optimal outcome here is simply the exact Bayesian posterior, and additional compute will simply be absorbed in better modelling of this posterior. If, due to the nature of the dataset, this posterior does not assign significant probability mass to unsafe outcomes, then in fact more computation and better algorithms should improve safety, and the primary risk comes from misgeneralization, i.e. erroneously assigning probability mass to dangerous behaviours which are not as likely as in the true posterior. Moreover, since amortized optimizers are just generative models, it is highly likely that all amortized optimizers obey the same power-law scaling we observe in current generative modelling, which means sharply diminishing (power-law) returns on additional compute and data investment. This means that, at least in theory, the out-of-distribution behaviour of amortized agents can be precisely characterized even before deployment, and is likely to concentrate around previous behaviour. Moreover, the out-of-distribution generalization capabilities should scale in a predictable way with the capacity of the function approximator, of which we now have precise mathematical characterizations due to scaling laws.
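To see how sharp these diminishing returns are, here is a worked example assuming loss follows a pure power law in compute, $L(C) = a C^{-b}$. The functional form with no irreducible term, the constants $a$ and $b$, and the value $b = 0.05$ are illustrative assumptions, not figures from any particular scaling-law fit.

$$\frac{L(2C)}{L(C)} = 2^{-b} \approx 0.966 \quad (b = 0.05), \qquad \frac{L(kC)}{L(C)} = \frac{1}{2} \iff k = 2^{1/b} = 2^{20} \approx 10^{6}$$

Under these assumptions, doubling compute buys only about a 3.4% reduction in loss, and halving the loss takes roughly a million times more compute.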

The distinction between direct and amortized optimizers also clarifies what I think is the major conceptual distinction between perspectives such as shard theory vs ‘classical’ AGI models such as AIXI. Shard theory and related works are primarily based around describing what amortized agents look like. Amortized agents do not explicitly optimize for rewards but rather repeat and generalize behaviours at test-time that led to reward in the past (for direct optimizers, however, reward is very much the optimization target). All the optimization occurred during training, when a policy was learnt that attempted to maximize reward given a set of empirically known transitions. When this policy is applied to novel situations, however, the function approximator is not explicitly optimizing for reward, but instead just generalizing across ‘what the policy would do’. Thus, a key implicit claim of shard theory is that the AGIs that we build will end up looking much more like current model-free RL agents than planners like AlphaGo and, ultimately, AIXI. Personally, I think something like this is quite likely due to the intrinsic computational and ontological difficulties with model-based planning in open-ended worlds, which I will develop in a future post.

For alignment, my key contention is that we should be very aware of whether we are thinking of AGI systems as direct or amortized optimizers or some combination of both. Such systems would have potentially very different safety properties. Yudkowsky’s vision is essentially of a direct optimizer of unbounded power. For such an agent, indeed the only thing that matters and that we can control is its reward function, so alignment must focus entirely on designing the reward function to be safe, corrigible, etc. For amortized agents, however, the alignment problem looks very different. Here, while the design of the reward function is important, so too is the design of the dataset. Moreover, it seems likely that for such amortized agents we are much less likely to see sudden capability jumps with very little data, and so they are likely much safer overall. Such amortized agents are also much closer to the cognitive architecture of humans, who have neither fixed utility functions nor unbounded planning ability. It is therefore possible that we might be able to imprint upon them a general fuzzy notion of ‘human values’ in a way we cannot do with direct optimizers.

The fundamental question, then, is figuring out the likely shape of near-term AGI so we can adapt our alignment focus to it. Personally, I think that a primarily amortized hybrid architecture is most likely, since the computational advantages of amortization are so large and since this appears to be how humans operate as well. However, epistemically, I am still highly uncertain on this point, and things will clarify as we get closer to AGI.

Postscript 1: Are humans direct or amortized optimizers?

There is actually a large literature in cognitive science which studies this exact question, although typically under the nomenclature of model-based vs model-free reinforcement learners. The answer appears to be that humans are both. When performing tasks that are familiar, or when not concentrating, or under significant mental load (typically having to do multiple disparate tasks simultaneously), humans respond in an almost entirely amortized fashion. However, when faced with challenging novel tasks and when they have mental energy to spare, humans are also capable of model-based, planning-like behaviour. These results heavily accord with (at least my) phenomenology, where we usually act ‘on autopilot’, but when we really want something we are capable of marshalling a significant amount of direct optimization power against a problem [4] [5]. Such a cognitive architecture makes sense for an agent with limited computational resources and, as discussed, such hybrid architectures are increasingly common in machine learning as well, at least for those that actually use direct optimization. However, while current approaches have a fixed architecture where amortization is always used in specific ways (i.e. to initialize a policy or to make value function estimates), humans appear to be able to flexibly shift between amortized and direct optimization according to task demands, novelty, and level of mental load.

Postscript 2: Mesaoptimizers

My epistemic status on this is fairly uncertain, but I think that this framing also gives some novel perspective on the question of mesaoptimization raised in the risks from learned optimization post. Using our terminology, we can understand the danger and likelihood of mesaoptimizers as the question of whether performing amortized optimization will tend to instantiate direct optimizers somewhere within the function approximator. The idea is that, for some problems, the best way to obtain the correct posterior parameters is to actually solve the problem using direct optimization. This is definitely a possibility, and may mean that we can unsuspectingly obtain direct optimizers from what appear to be amortized systems, and hence that we may overestimate the safety and underestimate the generalizability of our systems. I am pretty uncertain, but it feels sensible to me that the main constraint on mesa-optimization occurring is the extent to which the architecture of the mapping function is conducive to the implementation of direct optimizers. Direct optimizers typically have a very specific architecture requiring substantial iteration and search. Luckily, it appears that our current NN architectures, with a fixed-length forward pass and a lack of recurrence or support for the branching computations required in tree search, make the implementation of powerful mesa-optimizers inside the network quite challenging. However, this assessment may change as we scale up networks or continue to improve their architectures towards AGI.

To begin to understand the risk of mesaoptimization, it is important to make a distinction between within-forward-pass mesaoptimizers and mesaoptimizers that could form across multiple forward passes. A mesaoptimizer within the forward pass would form when, in order to implement some desired functionality, a direct optimizer has to be implemented within the amortized function. An early example of this sort potentially happening is the recent discovery that language models can learn to perform arbitrary linear regression problems, and hence potentially implement an internal iterative algorithm like gradient descent. Another possibility is that mesaoptimizers could form across multiple linked forward passes, when the forward passes can pass information to later ones. An example of this occurs in autoregressive decoding in language models, where earlier forward passes write output tokens which are then fed into later forward passes as inputs. It would be possible for a mesaoptimizer to form across forward passes by steganographically encoding additional information into the output tokens to be reused in later computations. While probably challenging to form in the standard autoregressive decoding task, such information transmission is almost entirely the point of other ideas like giving models access to ‘scratch-pads’ or external memory.
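The information-passing mechanism itself is easy to sketch. In the toy below, `forward_pass` is a hypothetical stand-in for a network with a fixed amount of computation per call and no hidden state; it is not a mesaoptimizer, and Newton’s method for $\sqrt{2}$ is just an assumed example of an iterative algorithm that only emerges across calls.

```python
def forward_pass(tokens: list) -> str:
    """One fixed-cost 'forward pass': read the last token, emit one refinement.

    Hypothetical stand-in for a network; each call does a single Newton step
    towards sqrt(2) and has no memory other than the tokens it is given.
    """
    guess = float(tokens[-1])
    return repr(0.5 * (guess + 2.0 / guess))


tokens = ["1.0"]  # the prompt / scratch-pad
for _ in range(6):
    tokens.append(forward_pass(tokens))  # output is fed back in as input

print(tokens[-1])  # close to 1.41421356..., though no single pass iterates
```

No individual pass contains a loop, yet the system as a whole runs an iterative optimization, with the output tokens serving as its only working memory.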

  1. ^

    People have also experimented with direct optimization even in perceptual tasks and found, unsurprisingly, that it can improve performance (although it may not be worth the additional compute cost).

  2. ^

    This is somewhere I suspect that we are not bottlenecked by fundamental reasons, but that we simply haven’t found the right algorithms or conceptualized things in the right way. I suspect that the underlying algorithmic truth is that continuous action spaces are harder by a constant factor, but not in qualitative terms the way it is now.

  3. ^

    Mathematically speaking, this is not actually a limitation but the desired behaviour. The goal of an amortized agent is to learn a posterior distribution over the solution ‘dataset’. More compute will simply result in better approximating this posterior. However, this posterior models the dataset and not the true solution space. If the dataset is not likely to contain the true solution to a problem, perhaps because it requires capabilities far in excess of any of the other solutions, it will not have high probability in the true posterior. This is exactly why alignment ideas like ‘ask GPT-N for a textbook from 2100 describing the solution to the alignment problem’ cannot work, even in theory. GPT is trained to estimate the posterior of sequences over its current dataset (the internet text corpus as of 2022 (or whenever)). The probability of such a dataset containing a true textbook from 2100 with a solution to alignment is 0. What does have probability is a bunch of text from humans speculating as to what such a solution would look like, given current knowledge, which may or may not be correct. Therefore, improving GPT by letting it produce better and better approximations to the posterior will not get us any closer to such a solution. The only way this could work is if GPT-N somehow misgeneralized its way into a correct solution and the likelihood of such a misgeneralization should ultimately decrease with greater capabilities. It is possible however that there is some capabilities sweet spot where GPT-N is powerful enough to figure out a solution but not powerful enough to correctly model the true posterior.

  4. ^

    This has also been noted with respect to language model samples (in a LW-affiliated context). To nuance this: humans who are not concentrating are amortizing intelligences.

  5. ^

    Similarly to humans, we can clearly view ‘chain of thought’ style prompting in language models as eliciting a very basic direct optimization or planning capability which is probably a basic version of our natural human planning or focused attention capabilities. Equivalently, ‘prompt programming’ can be seen as just figuring out how best to query a world model and hence (haphazardly) applying optimization power to steer it towards its desired output. Neuroscientifically, I see human planning as occurring in a similar way with a neocortical world model which can be sequentially ‘queried’ by RL loops running through prefrontal cortex → basal ganglia → thalamus → sensory cortices.