Basic Mathematics of Predictive Coding

This is an overview of the classic 1999 paper by Rajesh Rao and David Ballard that introduced predictive coding. I’m going to focus on explaining the mathematical setup instead of just staying at a conceptual level. I’ll do this in a series of steps of increasing detail and handholding, so that you can grasp the ideas by reading quickly if you are already familiar with the underlying mathematics. In addition, I have implemented a convnet version of this framework in pytorch, which you can look at in this notebook:

Why I wrote this

The phrases “predictive coding” and “predictive processing” have squeezed their way outside of academia and into the public. They occur quite commonly on Lesswrong, and in related spaces. They are used to point towards a set of ideas that sound something like “the brain is trying to predict sensory input” or even more handwavy “everything about the mind and psychology is about prediction.” This makes predictive coding[1] sound like a big theory of everything, but it was actually introduced into neuroscience as a very concrete mathematical model for the visual cortex.

If you are interested in these ideas, it is helpful to understand the mathematical framework that underlies the more grandiose statements. If nothing else, it grounds your thinking, and it potentially gives you a real mathematical playground to test any ideas you might have about predictive coding. Additionally, 1999 was quite a while ago, and in many ways I think the mathematics of predictive coding makes more intuitive sense in today’s deep learning ways of thinking than in the way it was originally presented. I’ve tried to port the original work into something that looks more at home in today’s ways of thinking (e.g. the code on github I linked to implements a predictive coding network in pytorch and uses a convnet structure, etc.).

A note of clarification: there are other network implementations of predictive processing, but this is the original one that started everything. Karl Friston has said that this is the paper that led him to active inference.

A second note of clarification: My understanding of these terms is that predictive processing is an umbrella term that especially encompasses the more theory-of-everything type ideas, while predictive coding really refers to a small set of particular network architectures, like the one described here. But I do think people use these terms quite sloppily, even in the academic literature.

Summary of the mathematical framework

Don’t worry if this doesn’t make sense to you right now; we’ll be going through every piece. In order to get a predictive coding network, we:

  • Write down a model of how causes in the world relate to retinal images

  • Use Bayes’ Rule to write out an equation for the posterior probabilities of causes given an image

  • Calculate the gradient descent rules for performing the optimization to maximize the posterior

  • Interpret those gradients as the dynamic equations for a recurrent neural network

The Generative Model

The goal is to make a network that takes in images and represents the image in the activations of the network. More specifically, we think of the network as representing the causes of the image. For example, a cow is a physical object that exists out in the world. When we see a cow, photons bouncing off the cow hit our retina and our brains try to infer the physical setting of the world which caused the retinal image, i.e. a cow. We are not after a representation of the retinal image itself, but instead of the cow which caused the retinal image.

The retinal image we call $\mathbf{I}$, and it is represented as a vector of pixel values.

The causes of retinal images are given by a matrix $U$ and a vector $\mathbf{r}$. The columns of $U$ are a basis for the causes. We think of $\mathbf{r}$ as the representation of a cause in the basis given by $U$. To spell this out a bit more, the numbers in $\mathbf{r}$ are the weights associated with each basis cause (i.e. each column of $U$), and the multiplication $U\mathbf{r}$ is a linear combination of basis causes. There is also a nonlinearity, $f$, such that we assume images are related to causes in the following way

$$\mathbf{I} = f(U\mathbf{r}) + \mathbf{n}$$

where $\mathbf{n}$ is a noise vector that accounts for any error between the image and the inferred causes. We call this equation the generative model since it describes how causes in the world generate retinal images.
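To make the generative model concrete, here is a minimal sketch of it in PyTorch. The sizes, the choice of $\tanh$ for $f$, and the unit-variance noise are my own assumptions for illustration, not values from the paper:

```python
import torch

n_pixels, n_causes = 256, 32

U = torch.randn(n_pixels, n_causes)  # columns of U are the basis causes
r = torch.randn(n_causes)            # r weights each basis cause
n = torch.randn(n_pixels)            # Gaussian noise vector

I = torch.tanh(U @ r) + n            # the generative model: I = f(U r) + n
```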

The rest of predictive coding (in the style of Rao and Ballard) is taking that generative model and applying Bayes Theorem to it, and then using gradient descent to maximize the posterior given a dataset. The gradient descent equations define a dynamics which are then interpreted as the dynamic activity of a recurrent neural network. And that’s basically it! If you aren’t interested in more detail or seeing the actual equations worked out, you can stop here, or skip to the simulation results.

A Fast Mathematical Tour from the Generative Model to Predictive Coding

A Simplified Non-hierarchical Version of Predictive Coding

  1. Start with the generative model, $\mathbf{I} = f(U\mathbf{r}) + \mathbf{n}$.

  2. Now compute the posterior using Bayes’ rule: $p(\mathbf{r}, U \mid \mathbf{I}) \propto p(\mathbf{I} \mid \mathbf{r}, U)\, p(\mathbf{r})\, p(U)$. We have probability distributions since, for a given $\mathbf{r}$ and $U$, we still have randomness from the noise term. We assume the noise is normally distributed with mean 0, variance $\sigma^2$, and no covariance structure.

  3. We want to find the optimal $U$ that forms a good basis for our set of images, and given a particular image, a setting of $\mathbf{r}$ that optimally represents the image in our basis. To do this we maximize the posterior, which is the same as minimizing the negative log of the posterior. Churning out the math we get that the negative log of the posterior is $E = \frac{1}{\sigma^2}\big(\mathbf{I} - f(U\mathbf{r})\big)^T\big(\mathbf{I} - f(U\mathbf{r})\big) + g(\mathbf{r}) + h(U)$, where those last two terms, $g(\mathbf{r})$ and $h(U)$, are negative logs of the priors.

  4. Now we calculate derivatives of the negative log posterior with respect to $\mathbf{r}$ and $U$, and follow these gradients over time to find our optimal parameters. The discrete time implementation of this process gives us the following recurrent equations: $\hat{\mathbf{r}}(t+1) = \hat{\mathbf{r}}(t) + \frac{k_1}{\sigma^2} U^T \frac{\partial f}{\partial \mathbf{x}}^T \big(\mathbf{I} - f(U\hat{\mathbf{r}})\big) - \frac{k_1}{2} g'(\hat{\mathbf{r}})$, and $U(t+1) = U(t) + \frac{k_2}{\sigma^2} \frac{\partial f}{\partial \mathbf{x}}^T \big(\mathbf{I} - f(U\hat{\mathbf{r}})\big)\,\hat{\mathbf{r}}^T - \frac{k_2}{2} h'(U)$, where $\mathbf{x} = U\hat{\mathbf{r}}$ and $k_1, k_2$ are step sizes (a code sketch of these updates follows the list of terms below).

Interesting terms in these equations:

  • $f(U\hat{\mathbf{r}})$: the prediction of the image

  • $\mathbf{I} - f(U\hat{\mathbf{r}})$: the prediction error
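To make these updates concrete, here is a minimal sketch of one non-hierarchical update step in PyTorch. I assume $f = \tanh$ and Gaussian priors on $\mathbf{r}$ and $U$ (so that $g'(\mathbf{r}) = 2\alpha\mathbf{r}$ and $h'(U) = 2\lambda U$); the step sizes and other constants are placeholders, not values from the paper:

```python
import torch

def predictive_coding_step(I, U, r, k1=0.05, k2=0.005, sigma2=1.0, alpha=0.1, lam=0.1):
    """One discrete-time update of r (inference) and U (learning) for I = f(U r) + n, with f = tanh."""
    pred = torch.tanh(U @ r)          # f(U r): the prediction of the image
    err = I - pred                    # the prediction error
    df = 1 - pred**2                  # derivative of tanh at U r
    # follow the gradient of the negative log posterior with respect to r
    r_new = r + (k1 / sigma2) * (U.T @ (df * err)) - k1 * alpha * r
    # follow the gradient with respect to U (in practice averaged over many images)
    U_new = U + (k2 / sigma2) * torch.outer(df * err, r) - k2 * lam * U
    return U_new, r_new
```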

Before we go into how these terms are interpreted as a recurrent neural network, let’s extend the model to make it hierarchical.

Hierarchical Predictive Coding

  1. Start with the generative model as before, $\mathbf{I} = f(U_1\mathbf{r}_1) + \mathbf{n}_1$, but now make the vector $\mathbf{r}_1$ itself the input to a higher-order system which uses the generative model $\mathbf{r}_1 = f(U_2\mathbf{r}_2) + \mathbf{n}_2$. So in total we have a hierarchical generative model: $\mathbf{I} = f(U_1\mathbf{r}_1) + \mathbf{n}_1$ with $\mathbf{r}_1 = f(U_2\mathbf{r}_2) + \mathbf{n}_2$.

  2. Now compute the posterior $p(\mathbf{r}_1, \mathbf{r}_2, U_1, U_2 \mid \mathbf{I})$, using Bayes’ rule, as before.

  3. Compute the negative log of the posterior, as before.

  4. Now we perform gradient descent on these equations as in the previous step and get our discrete time equations: $\hat{\mathbf{r}}_1(t+1) = \hat{\mathbf{r}}_1(t) + \frac{k_1}{\sigma_1^2} U_1^T \frac{\partial f}{\partial \mathbf{x}_1}^T \big(\mathbf{I} - f(U_1\hat{\mathbf{r}}_1)\big) + \frac{k_1}{\sigma_2^2}\big(f(U_2\hat{\mathbf{r}}_2) - \hat{\mathbf{r}}_1\big) - \frac{k_1}{2} g'(\hat{\mathbf{r}}_1)$, and $\hat{\mathbf{r}}_2(t+1) = \hat{\mathbf{r}}_2(t) + \frac{k_1}{\sigma_2^2} U_2^T \frac{\partial f}{\partial \mathbf{x}_2}^T \big(\hat{\mathbf{r}}_1 - f(U_2\hat{\mathbf{r}}_2)\big) - \frac{k_1}{2} g'(\hat{\mathbf{r}}_2)$, with updates for $U_1$ and $U_2$ of the same form as before (see the code sketch after the list of terms below).

Interesting terms in these equations:

  • $f(U_1\hat{\mathbf{r}}_1)$: the low level prediction of the image

  • $\mathbf{I} - f(U_1\hat{\mathbf{r}}_1)$: the low level prediction error

  • $f(U_2\hat{\mathbf{r}}_2)$: the high level prediction (of $\mathbf{r}_1$)

  • $\hat{\mathbf{r}}_1 - f(U_2\hat{\mathbf{r}}_2)$: the higher level prediction error
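Here is the corresponding sketch of the inference updates for the two-level model, under the same assumptions as before ($f = \tanh$, Gaussian priors, made-up constants); `var1` and `var2` stand for $\sigma_1^2$ and $\sigma_2^2$:

```python
import torch

def hierarchical_inference_step(I, U1, r1, U2, r2, k1=0.05, var1=1.0, var2=1.0, alpha=0.1):
    """One inference update of r1 and r2 for I = f(U1 r1) + n1, r1 = f(U2 r2) + n2, with f = tanh."""
    pred1 = torch.tanh(U1 @ r1)       # low level prediction of the image
    err1 = I - pred1                  # low level prediction error
    pred2 = torch.tanh(U2 @ r2)       # high level prediction of r1
    err2 = r1 - pred2                 # higher level prediction error
    df1 = 1 - pred1**2
    df2 = 1 - pred2**2
    # r1 is pushed by the bottom-up error and pulled toward the top-down prediction
    r1_new = r1 + (k1 / var1) * (U1.T @ (df1 * err1)) - (k1 / var2) * err2 - k1 * alpha * r1
    # r2 only sees the error between r1 and its own prediction of r1
    r2_new = r2 + (k1 / var2) * (U2.T @ (df2 * err2)) - k1 * alpha * r2
    return r1_new, r2_new
```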

Interpreting these equations as a recurrent neural network

We can visualize this hierarchical network as follows:

Hopefully by following the arrows and looking at the equations it’s pretty clear how they relate to the recurrent network. Here is a copy-paste from a recent review by Jiang and Rao (2022).

More detail about the process

The setup is that we are given a set of images (we will call a particular image $\mathbf{I}$) and want to infer causes: a basis $U$ over the whole dataset and an $\mathbf{r}$ for each image. We can use Bayesian inference to do this! Using the generative model, we can compute a likelihood, the probability of an image given causes, $p(\mathbf{I} \mid \mathbf{r}, U)$. This is a distribution because of the noise term in the generative model. That is, given a particular setting of both $\mathbf{r}$ and $U$, we still have a distribution over $\mathbf{I}$ because of the randomness of $\mathbf{n}$. Ultimately we wish to find causes that maximize the posterior $p(\mathbf{r}, U \mid \mathbf{I})$; in other words, we want to find the causes that are most probable given an image. Using Bayes’ theorem we have

$$p(\mathbf{r}, U \mid \mathbf{I}) \propto p(\mathbf{I} \mid \mathbf{r}, U)\, p(\mathbf{r})\, p(U),$$

where $p(\mathbf{r})$ and $p(U)$ are priors on the causes. If we find the values of the causes that maximize the right-hand side of the equation, we will maximize the left-hand side. This is the approach we will take.

Computing likelihoods and priors

Remember the generative model has a noise term $\mathbf{n}$:

$$\mathbf{I} = f(U\mathbf{r}) + \mathbf{n}$$

We assume that $\mathbf{n}$ is normally distributed with mean 0 and variance $\sigma^2$, and with no covariance structure. Since normal distributions have exponents in them, we will take logarithms. The logarithm is monotonically increasing, so maximizing $p(\mathbf{r}, U \mid \mathbf{I})$ is the same as maximizing $\log p(\mathbf{r}, U \mid \mathbf{I})$. Taking logarithms has the added bonus that the multiplication of the likelihood and the priors becomes an addition. By convention we like minimizing functions instead of maximizing, so we will also take the negative logarithm of our inference equation, and we will call that $E$:

$$E = -\log\big(p(\mathbf{I} \mid \mathbf{r}, U)\,p(\mathbf{r})\,p(U)\big) = -\log p(\mathbf{I} \mid \mathbf{r}, U) - \log p(\mathbf{r}) - \log p(U)$$

We want to find $\mathbf{r}$ and $U$ that minimize $E$!

I’m not going to go through the details of the math here, but using the equation for normal distributions we get the following form for $E$:

$$E = \frac{1}{\sigma^2}\big(\mathbf{I} - f(U\mathbf{r})\big)^T\big(\mathbf{I} - f(U\mathbf{r})\big) + g(\mathbf{r}) + h(U),$$

where the functions $g$ and $h$ are the negative logarithms of the priors on $\mathbf{r}$ and $U$, respectively.
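As a sanity check, here is how this $E$ could be computed in code, assuming $f = \tanh$ and Gaussian priors so that $g(\mathbf{r}) = \alpha\,\mathbf{r}^T\mathbf{r}$ and $h(U) = \lambda\sum_{ij}U_{ij}^2$ (the constants are placeholders):

```python
import torch

def negative_log_posterior(I, U, r, sigma2=1.0, alpha=0.1, lam=0.1):
    """E = (1/sigma^2) ||I - f(U r)||^2 + g(r) + h(U), with f = tanh and Gaussian priors."""
    err = I - torch.tanh(U @ r)       # prediction error
    E = (err @ err) / sigma2          # likelihood term (negative log, up to constants)
    E = E + alpha * (r @ r)           # g(r): negative log prior on r
    E = E + lam * (U * U).sum()       # h(U): negative log prior on U
    return E
```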

Making the Network Hierarchical

We assumed that our generative model took the form $\mathbf{I} = f(U\mathbf{r}) + \mathbf{n}$, which is to say that images are caused by causes. But what if those causes are themselves caused by more abstract, higher-level causes, given by a matrix $U_2$ and a vector $\mathbf{r}_2$ (we will now write $U_1$, $\mathbf{r}_1$, $\mathbf{n}_1$, and $\sigma_1^2$ for the first-level quantities)?

$$\mathbf{r}_1 = f(U_2\mathbf{r}_2) + \mathbf{n}_2$$

In this way, we treat the causes, $\mathbf{r}_1$, of the retinal image as if they were sensory input to a more abstract system. This more abstract system is trying to infer higher-level causes $\mathbf{r}_2$, in a basis given by $U_2$, of $\mathbf{r}_1$. Now the overall posterior will be $p(\mathbf{r}_1, \mathbf{r}_2, U_1, U_2 \mid \mathbf{I})$. As before, we can similarly derive an overall $E$:

$$E = \frac{1}{\sigma_1^2}\big(\mathbf{I} - f(U_1\mathbf{r}_1)\big)^T\big(\mathbf{I} - f(U_1\mathbf{r}_1)\big) + \frac{1}{\sigma_2^2}\big(\mathbf{r}_1 - f(U_2\mathbf{r}_2)\big)^T\big(\mathbf{r}_1 - f(U_2\mathbf{r}_2)\big) + g(\mathbf{r}_1) + g(\mathbf{r}_2) + h(U_1) + h(U_2)$$

Finding $\mathbf{r}_1$, $\mathbf{r}_2$, $U_1$, and $U_2$ that minimize $E$ will be the same as maximizing the posterior.

We can generalize this situation to add more levels to the hierarchy:

$$\mathbf{r}_{i-1} = f(U_i\mathbf{r}_i) + \mathbf{n}_i, \quad i = 1, 2, 3, \dots$$

where $\mathbf{r}_0 \equiv \mathbf{I}$ is the image itself.
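Putting the pattern together, the multi-level objective then takes the form (a sketch generalizing the two-level $E$ above, again writing $\mathbf{r}_0 \equiv \mathbf{I}$):

$$E = \sum_i \left[\frac{1}{\sigma_i^2}\big(\mathbf{r}_{i-1} - f(U_i\mathbf{r}_i)\big)^T\big(\mathbf{r}_{i-1} - f(U_i\mathbf{r}_i)\big) + g(\mathbf{r}_i) + h(U_i)\right]$$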

Minimizing $E$

Finding the representation of a specific image

To minimize $E$ we will take derivatives and then use gradient descent, as one is used to in deep learning. Let’s start with the gradient descent dynamics for $\mathbf{r}_1$, obtained by following the negative gradient of $E$ with respect to $\mathbf{r}_1$:

$$\frac{d\mathbf{r}_1}{dt} = -\frac{k_1}{2}\frac{\partial E}{\partial \mathbf{r}_1} = \frac{k_1}{\sigma_1^2} U_1^T \frac{\partial f}{\partial \mathbf{x}_1}^T\big(\mathbf{I} - f(U_1\mathbf{r}_1)\big) + \frac{k_1}{\sigma_2^2}\big(f(U_2\mathbf{r}_2) - \mathbf{r}_1\big) - \frac{k_1}{2}g'(\mathbf{r}_1)$$

where $\mathbf{x}_1 = U_1\mathbf{r}_1$ and $k_1$ is a step-size parameter.

This equation is certainly a bit messy. Let’s go through it slowly. The first thing that jumps out is the term $\mathbf{I} - f(U_1\mathbf{r}_1)$. This is the difference between the actual image and the image estimated by the current setting of $\mathbf{r}_1$. In other words, this is how much error there is between your prediction of what the image is, based on your current guess of what the cause is, and the actual image, and so we call $\mathbf{I} - f(U_1\mathbf{r}_1)$ the prediction error.

Similarly, $\mathbf{r}_1 - f(U_2\mathbf{r}_2)$ is the error between the current setting of $\mathbf{r}_1$ and the prediction of $\mathbf{r}_1$ given by the higher-level causes. This is a higher-level prediction error.

Note that both the base-level and higher-level prediction error terms are scaled by the inverse of a variance term ($1/\sigma_1^2$ and $1/\sigma_2^2$), called the precision. The smaller your variance for the likelihood, the more precise the prediction error is, and so the relative contribution of that prediction error goes up. A lot of interesting ideas about how predictive coding is related to psychological phenomena (e.g. mental illness) have to do with errors and manipulations of precision in the brain.

We have the following terms of interest at each level $i$ of the hierarchy (writing $\mathbf{r}_0 \equiv \mathbf{I}$):

  • $\mathbf{r}_{i-1}$ - the representation to be explained/predicted

  • $f(U_i\mathbf{r}_i)$ - the prediction of the lower level representation

  • $\mathbf{r}_{i-1} - f(U_i\mathbf{r}_i)$ - the prediction error

From the dynamical equation we can see that the prediction error gets multiplied by a matrix $\frac{2}{\sigma_i^2}U_i^T$ (together with the derivative of the nonlinearity), which carries the error upward to update $\mathbf{r}_i$.

Note: I haven’t finished this section but if you are following up to this point I think you probably get the idea. For the purposes of actually getting this out I’m leaving this as is, sorry.

Comparisons to standard deep-learning systems

Gradient descent is used to minimize loss in two ways

In the standard deep-learning setup gradient descent is used to find parameters of the network that minimize a loss function. The same thing happens in predictive coding networks, but it actually happens twice.

The first way gradient descent is used is the “normal” way: to learn features of the input that are useful for minimizing the loss. In our case this is how we change the parameters in the $U$ matrices. These changes are computed over batches of many images.

The second way gradient descent is used is to find the activations of the neural network, in this case to find the values of the $\mathbf{r}_i$, given a single input image. This is not standard from the perspective of modern deep learning, but should set off mesa-optimizer alarm bells.
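Schematically, the two uses of gradient descent nest inside each other. Here is a minimal non-hierarchical sketch (the sizes, constants, and tanh nonlinearity are made up for illustration; the convnet version in the notebook is more involved):

```python
import torch

def train_predictive_coding(dataset, n_pixels=256, n_causes=32, n_epochs=10,
                            n_inference_steps=30, k1=0.05, k2=0.005,
                            sigma2=1.0, alpha=0.1, lam=0.1):
    """Outer loop: slow gradient descent on U over the dataset (learning).
    Inner loop: fast gradient descent on r for each image (inference)."""
    U = 0.1 * torch.randn(n_pixels, n_causes)
    for epoch in range(n_epochs):
        for I in dataset:                          # I: a flattened image tensor
            r = torch.zeros(n_causes)
            for t in range(n_inference_steps):     # inference: settle r for this image
                pred = torch.tanh(U @ r)
                df = 1 - pred**2
                r = r + (k1 / sigma2) * (U.T @ (df * (I - pred))) - k1 * alpha * r
            pred = torch.tanh(U @ r)               # learning: one slow update of U
            df = 1 - pred**2
            U = U + (k2 / sigma2) * torch.outer(df * (I - pred), r) - k2 * lam * U
    return U
```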

Training and inference are both dynamic

To expand on the previous point, in the standard case we use gradient descent for learning over a dataset, in other words for training our neural network. Thus, we can think of training as being a dynamical process. After that process, the parameters of the model are frozen and any new input can simply be run through the network to create an output. So inference is simply a bunch of matrix multiplications and a few nonlinearities in a fixed network, and (usually) does not have much in the way of dynamics.

In predictive coding, inference is also a dynamic process. Even after training is over, for a given input image, we obtain the activations of our network by applying gradient descent dynamics.

We can think of training as a long-timescale process and inference as a short-timescale process. While this is kind of true in the standard setting as well, it is only true in a trivial way, since standard inference usually doesn’t have any notion of time or dynamics associated with it.

This is kinda sorta like an autoencoder

At least conceptually, though I do believe I’ve seen some work formally relating them, but I can’t find the reference right now. In an autoencoder, you train the output to reproduce the input, after forcing all the input information through a bottleneck/latent space. In the predictive coding network, the network is, by virtue of its gradient descent dynamics to maximize the posterior, trying to reproduce the sensory input at the earliest layers of processing, via a hierarchical set of feedback connections which act as predictions. These predictions continually get refined by virtue of the recurrent dynamics in the network. It’s like an autoencoder folded in on itself, with a bit more dynamics.

Some simple code/​simulations

I’ve implemented a convnet pytorch version of this predictive coding network, which you can see in this notebook:

You can see the hyperparameters I’ve chosen in the model instantiation

```python
model = PredictiveCodingNetwork(
    input_size=(3, image_shape[1], image_shape[2]),
    n_layers=2,
    n_causes=[8, 10],
    kernel_size=[[8, 8], [3, 3]],
    stride=[4, 2],
    padding=0,
    lam=0.1,
    alpha=0.1,
    k1=0.005,
    k2=0.05,
    sigma2=10.0,
)
```

The main ones are that there are 2 layers, with 8 basis causes in the first layer and 10 in the second, and that each cause in the first layer is a kernel of size 8x8 pixels, while each cause in the second layer is a kernel of size 3x3 (whose units are no longer pixels but first-layer kernels).

With this we can train on a set of images, and then look at the input image, the prediction by the network (in the first layer, after multiple timesteps) and the difference, which I’m showing in the 3 columns of this figure:

One can also visualize the basis set in the first layer:

Obviously there’s a lot more one can do. For instance, one can look at how the predictions change and converge over time, or can visualize predictions in higher layers, or can intervene on the network at higher levels in order to manipulate predictions from the top-down, etc.
