In part 1, we implemented exact recursive Bayesian updates for linear regression, taking advantage of conjugate priors to derive closed-form posterior updates.
In part 2, we introduced variational inference as a way to approximate intractable posteriors in Bayesian neural networks. Specifically, we approximate the true posterior $p(\theta \mid \mathcal{D})$ with a distribution from our chosen ‘variational family’; we chose the family of Gaussian distributions. In batch variational inference, we maximise the Evidence Lower Bound (ELBO) over the entire dataset to find the parameters $\{\mu, \Sigma\}$ for which our approximation is the best possible fit to the true posterior.
R-VGA (the Recursive Variational Gaussian Approximation) extends this to the streaming setting we saw in part 1, where new data $(X_t, y_t)$ arrives in batches, one per timestep $t$. In part 1, our choice of a linear Gaussian model with Gaussian priors gave an exact expression for the posterior. As in part 2, however, things are more complicated when we opt for a non-linear neural network.
In this post I’ll follow the derivation from Lambert’s original paper [3] and will reproduce the experiment in section 2.4.4 from [4]. Namely, we’ll build a Bayesian neural network to model a streaming variation of the classic ‘make moons’ binary classification dataset.
A Jupyter notebook version of the code in this post is available here.
The R-VGA algorithm’s ‘big picture’ is very similar to the story told in part 1. Armed with a chosen measurement model and some Gaussian priors, we:

1. Use the previous posterior as our prior when updating on a new batch of data: $q_{t-1}(\theta) = \mathcal{N}(\mu_{t-1}, \Sigma_{t-1})$. If this is our first batch, we use our initial prior $q_0(\theta) = \mathcal{N}(\mu_0, \Sigma_0)$.
2. Update the variational parameters: find new $\mu_t, \Sigma_t$ by optimising the ELBO with respect to the new data.
3. Repeat for the next batch.
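The loop is easiest to see concretely in the one case where step 2 is exact. Below is a minimal sketch for the linear-Gaussian model of part 1. It is an illustrative assumption, and the helper names `conjugate_update` and `run_stream` are hypothetical. Each batch's posterior is fed back in as the next prior, and the end result should match the posterior computed from all the data at once.

```python
import jax
import jax.numpy as jnp


def conjugate_update(mu, sigma, X, y, noise_var):
    # Exact Gaussian posterior for y = X @ theta + noise, noise ~ N(0, noise_var),
    # under the prior N(mu, sigma): this plays the role of step 2
    prior_prec = jnp.linalg.inv(sigma)
    post_prec = prior_prec + X.T @ X / noise_var
    post_sigma = jnp.linalg.inv(post_prec)
    post_mu = post_sigma @ (prior_prec @ mu + X.T @ y / noise_var)
    return post_mu, post_sigma


def run_stream(mu0, sigma0, batches, noise_var):
    # Steps 1 and 3: each batch's posterior becomes the prior for the next batch
    mu, sigma = mu0, sigma0
    for X, y in batches:
        mu, sigma = conjugate_update(mu, sigma, X, y, noise_var)
    return mu, sigma


k_x, k_noise = jax.random.split(jax.random.PRNGKey(0))
theta_true = jnp.array([1.5, -0.5])
X_all = jax.random.normal(k_x, (100, 2))
y_all = X_all @ theta_true + 0.1 * jax.random.normal(k_noise, (100,))

mu0, sigma0 = jnp.zeros(2), jnp.eye(2)
batches = [(X_all[i:i + 10], y_all[i:i + 10]) for i in range(0, 100, 10)]
mu_stream, sigma_stream = run_stream(mu0, sigma0, batches, noise_var=0.01)
mu_batch, sigma_batch = conjugate_update(mu0, sigma0, X_all, y_all, noise_var=0.01)
# mu_stream and mu_batch should agree up to floating-point round-off
```

For non-linear models, step 2 has no closed form. The rest of this post is about approximating it.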
The update this time is more involved. We start with an expansion of the KL divergence term, following [3].
Let $Y_t = (y_1, \dots, y_t)$ denote the observations up to and including time $t$, and take $q_t(\theta)$ to be our variational approximation to the true posterior $p(\theta \mid Y_t)$ at time $t$.
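The expansion below is reconstructed from the steps described in the next paragraph, so the notation may differ slightly from [3]:

$$
\begin{aligned}
\mathrm{KL}\big(q_t(\theta)\,\|\,p(\theta \mid Y_t)\big)
&= \int q_t(\theta) \log \frac{q_t(\theta)}{p(\theta \mid Y_t)}\, d\theta \\
&= \int q_t(\theta) \log \frac{q_t(\theta)\, p(y_t \mid Y_{t-1})}{p(y_t \mid \theta, Y_{t-1})\, p(\theta \mid Y_{t-1})}\, d\theta \\
&= \int q_t(\theta) \log \frac{q_t(\theta)\, p(y_t \mid Y_{t-1})}{p(y_t \mid \theta)\, p(\theta \mid Y_{t-1})}\, d\theta \\
&= \int q_t(\theta) \log \frac{q_t(\theta)\, p(Y_t)}{p(y_t \mid \theta)\, p(\theta \mid Y_{t-1})\, p(Y_{t-1})}\, d\theta
\end{aligned}
$$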
Going line by line, we have:

- used the definition of the KL divergence;
- split up $Y_t = (Y_{t-1}, y_t)$ and used Bayes’ theorem;
- used our assumption that $y_t$ is conditionally independent of $Y_{t-1}$ given $\theta$;
- multiplied top and bottom by $p(Y_{t-1})$ and used the definition of conditional probability.

So far, so good. Next, we expand the log term and the integral:
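Here it is written out term by term. This is a reconstruction that follows the steps just described:

$$
\mathrm{KL}\big(q_t \,\|\, p(\theta \mid Y_t)\big)
= \int q_t(\theta) \log \frac{q_t(\theta)}{p(\theta \mid Y_{t-1})}\, d\theta
- \int q_t(\theta) \log p(y_t \mid \theta)\, d\theta
+ \int q_t(\theta) \log p(Y_t)\, d\theta
- \int q_t(\theta) \log p(Y_{t-1})\, d\theta
$$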
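The last two integrands depend on $\theta$ only through $q_t(\theta)$, because $\log p(Y_t)$ and $\log p(Y_{t-1})$ are constants. We can therefore pull the logarithms out of the integrals. In both cases we’re left with the integral of a probability density over its whole domain, which is 1: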
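Collecting terms gives the following (reconstructed) result. The last two terms no longer depend on $q_t$:

$$
\mathrm{KL}\big(q_t \,\|\, p(\theta \mid Y_t)\big)
= \mathrm{KL}\big(q_t \,\|\, p(\theta \mid Y_{t-1})\big)
- \mathbb{E}_{q_t}\big[\log p(y_t \mid \theta)\big]
+ \log p(Y_t) - \log p(Y_{t-1})
$$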
At this point we take the first approximation: $p(\theta \mid Y_{t-1}) \approx q_{t-1}(\theta)$. Lambert proves two things about it. First, if our variational family is conjugate to our likelihood, this approximation incurs the same error as batch variational inference. Second, for linear Gaussian models we recover the Kalman filter equations seen in part 1. I won’t recount that here; see [3] for details.
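Under this approximation, and ignoring the terms that don't depend on $q_t$, we minimise the objective below over $q_t = \mathcal{N}(\mu_t, \Sigma_t)$ (a reconstruction). We set its gradients with respect to $\mu_t$ and $\Sigma_t$ to zero, using the standard Gaussian KL gradients $\nabla_{\mu_t}\mathrm{KL} = \Sigma_{t-1}^{-1}(\mu_t - \mu_{t-1})$ and $\nabla_{\Sigma_t}\mathrm{KL} = \tfrac{1}{2}(\Sigma_{t-1}^{-1} - \Sigma_t^{-1})$. This gives the (implicit) update equations:

$$
\begin{aligned}
q_t &= \arg\min_{q}\; \mathrm{KL}\big(q \,\|\, q_{t-1}\big) - \mathbb{E}_{q}\big[\log p(y_t \mid \theta)\big] \\
\mu_t &= \mu_{t-1} + \Sigma_{t-1}\, \nabla_{\mu_t} \mathbb{E}_{q_t}\big[\log p(y_t \mid \theta)\big] \\
\Sigma_t^{-1} &= \Sigma_{t-1}^{-1} - 2\, \nabla_{\Sigma_t} \mathbb{E}_{q_t}\big[\log p(y_t \mid \theta)\big]
\end{aligned}
$$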
These two gradient-of-expectation terms would be a nightmare to compute directly. Happily, we can use the following two identities (see the appendix of [5] for a proof):
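The identities are stated below for $q = \mathcal{N}(\mu, \Sigma)$ and a twice-differentiable function $f$; they are often called Bonnet’s and Price’s theorems. Applied with $f(\theta) = \log p(y_t \mid \theta)$, they turn the gradient-of-expectation terms into expectation-of-gradient terms. A reconstruction of the resulting R-VGA updates follows them:

$$
\nabla_{\mu}\, \mathbb{E}_{q}\big[f(\theta)\big] = \mathbb{E}_{q}\big[\nabla_\theta f(\theta)\big],
\qquad
\nabla_{\Sigma}\, \mathbb{E}_{q}\big[f(\theta)\big] = \tfrac{1}{2}\, \mathbb{E}_{q}\big[\nabla^2_\theta f(\theta)\big]
$$

$$
\mu_t = \mu_{t-1} + \Sigma_{t-1}\, \mathbb{E}_{q_t}\big[\nabla_\theta \log p(y_t \mid \theta)\big],
\qquad
\Sigma_t^{-1} = \Sigma_{t-1}^{-1} - \mathbb{E}_{q_t}\big[\nabla^2_\theta \log p(y_t \mid \theta)\big]
$$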
and now we have expectation-of-gradient terms. We can compute these by approximating each expectation with an average over samples and using jax.grad and jax.hessian (erm… actually this isn’t quite true! More on this in a moment).
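As a quick sanity check of the two identities, here is a small Monte Carlo sketch in one dimension. It uses the illustrative choice $f(\theta) = \sin\theta$, an assumption picked because $\mathbb{E}[\sin\theta] = \sin\mu\, e^{-v/2}$ is known in closed form for $\theta \sim \mathcal{N}(\mu, v)$. The sample averages of $f'$ and $\tfrac12 f''$, computed with jax.grad, should match the derivatives of that expectation with respect to $\mu$ and $v$.

```python
import jax
import jax.numpy as jnp

mu, var = 0.7, 0.3


def f(theta):
    return jnp.sin(theta)


def expected_f(mu, var):
    # Closed-form E[sin(theta)] for theta ~ N(mu, var)
    return jnp.sin(mu) * jnp.exp(-var / 2.0)


# Left-hand sides: derivatives of the expectation w.r.t. the mean and the variance
d_mu, d_var = jax.grad(expected_f, argnums=(0, 1))(mu, var)

# Right-hand sides: Monte Carlo averages of f' and f''/2 under q
theta = mu + jnp.sqrt(var) * jax.random.normal(jax.random.PRNGKey(0), (200_000,))
mc_d_mu = jax.vmap(jax.grad(f))(theta).mean()
mc_d_var = 0.5 * jax.vmap(jax.grad(jax.grad(f)))(theta).mean()

print(d_mu, mc_d_mu)
print(d_var, mc_d_var)
```

The two printed pairs should agree to within Monte Carlo error.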
There are a few sticking points we need to think about before we can move on to the code…
Implicit vs. explicit schemes
Looking closely at the update rules above, we can see that they are only implicitly defined: the new parameters $(\mu_t, \Sigma_t)$ also appear on the right-hand side, tucked away in the expectation taken over $q_t = \mathcal{N}(\mu_t, \Sigma_t)$.
Lambert suggests that we can simply cheat (er, approximate) a little further. We take the expectation over a distribution we expect to be close to $q_t$ but which we already know, namely $q_{t-1} = \mathcal{N}(\mu_{t-1}, \Sigma_{t-1})$. He acknowledges, however, that the implicit scheme is optimal under the assumptions and approximations made so far, if we can live with its cost. While working on the following implementation I tried both. I found (1) that the results were very similar and (2) that the performance cost of solving the implicit scheme was tolerable. As a result, the code below implements the implicit scheme.
But how do we solve the implicit scheme? Using fixed-point iteration. For each update $(\mu_{t-1}, \Sigma_{t-1}) \to (\mu_t, \Sigma_t)$, we start with $q_{t-1}$ as our guess for $q_t$. We evaluate the right-hand sides of the update equations under the current guess, then repeat, hoping to converge to the solution $(\mu_t, \Sigma_t)$. This is, I admit, a little hand-wavy, and the code could be made more robust by checking the convergence of the variational parameters. In practice, though, I found that a small fixed number of fixed-point iterations worked fine.
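One way to add a convergence check is a jax.lax.while_loop. It stops once successive iterates are within a tolerance or a maximum number of iterations is reached. Below is a minimal, hypothetical helper (`fixed_point_with_tol` is not part of the post's code), illustrated on the scalar map $x \mapsto \cos x$. For R-VGA, `f` would instead map the current $(\mu, \Sigma)$ estimate to the next one.

```python
import jax
import jax.numpy as jnp


def fixed_point_with_tol(f, x0, tol=1e-6, max_iter=100):
    # Iterate x <- f(x) until the largest absolute change in any leaf of the
    # pytree x drops below tol, or max_iter iterations have been performed
    def change(x_new, x_old):
        diffs = jax.tree_util.tree_map(lambda a, b: jnp.max(jnp.abs(a - b)), x_new, x_old)
        return jnp.max(jnp.stack(jax.tree_util.tree_leaves(diffs))).astype(jnp.float32)

    def cond(state):
        i, _, delta = state
        return (i < max_iter) & (delta > tol)

    def body(state):
        i, x, _ = state
        x_new = f(x)
        return i + 1, x_new, change(x_new, x)

    init = (jnp.asarray(0, dtype=jnp.int32), x0, jnp.asarray(jnp.inf, dtype=jnp.float32))
    n_iter, x_final, _ = jax.lax.while_loop(cond, body, init)
    return x_final, n_iter


# Illustration: the fixed point of cos(x) is roughly 0.739085
x_star, n_iter = fixed_point_with_tol(jnp.cos, jnp.asarray(1.0, dtype=jnp.float32))
```

Fixed-point iteration isn't guaranteed to converge, so returning `n_iter` makes it easy to spot updates that hit `max_iter`.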
Computing $\mathbb{E}_{q_t}$
We need to compute two expectations over derivatives of our model’s log-likelihood. Lambert’s paper shows that for some simple models, such as logistic regression, we can perform some pretty impressive gymnastics to compute these expectations without further approximation. Here, however, I have opted for a small Monte Carlo estimate of both terms. We draw a fixed number of parameter vectors from our variational distribution and approximate each expectation integral with the sample average.
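Here is a minimal sketch of that Monte Carlo estimator. It uses a quadratic log-density as a stand-in for the log-likelihood, an assumption chosen because its expected gradient $b - A\mu$ and Hessian $-A$ are known exactly. Samples are drawn as $\theta = \mu + Lz$ with $LL^\top = \Sigma$ and $z \sim \mathcal{N}(0, I)$. jax.vmap then evaluates jax.grad and jax.hessian across the samples before averaging. The function name `mc_expectations` is hypothetical.

```python
import jax
import jax.numpy as jnp


def mc_expectations(log_lik, mu, sigma, key, n_samples):
    # Estimate E_q[grad log_lik] and E_q[hessian log_lik] for q = N(mu, sigma)
    L = jnp.linalg.cholesky(sigma)
    z = jax.random.normal(key, (n_samples, mu.size))
    theta = mu + z @ L.T  # each row is one sample from q
    grads = jax.vmap(jax.grad(log_lik))(theta)
    hessians = jax.vmap(jax.hessian(log_lik))(theta)
    return grads.mean(axis=0), hessians.mean(axis=0)


# Stand-in log-likelihood: a concave quadratic with known derivatives
A = jnp.array([[2.0, 0.3], [0.3, 1.0]])
b = jnp.array([1.0, -1.0])


def log_lik(theta):
    return -0.5 * theta @ A @ theta + b @ theta


mu = jnp.array([0.5, 0.2])
sigma = jnp.array([[0.2, 0.05], [0.05, 0.1]])
g_hat, H_hat = mc_expectations(log_lik, mu, sigma, jax.random.PRNGKey(1), n_samples=5000)
```

For a real neural network, `jax.hessian` is exactly the piece that causes trouble, as the next section explains.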
The Hessian
All of this is getting a bit complicated but, as I mentioned above, at least we can just use jax’s autodiff to compute the first and second derivative terms for us, right? Right?
Wrong!
For a non-linear model such as a neural network, we aren’t guaranteed that the Hessian of the log-likelihood is negative semidefinite. The precision update can then produce a matrix that isn’t positive definite. This can (and did, when I didn’t spot it in the implementation!) cause headaches when we come to invert it in the update equations. Instead, we can use the Gauss-Newton approximation to the Hessian:
For Bernoulli observations with mean $\mu_i(\beta)$, the log-likelihood is given by
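Write $f_i(\beta)$ for the model’s logit on input $i$, so that $\mu_i = \sigma(f_i(\beta))$ with $\sigma$ the logistic sigmoid. A reconstruction of the log-likelihood and its derivatives, in the logit form the code uses, is:

$$
\begin{aligned}
\ell(\beta) &= \sum_i \Big[ y_i \log \mu_i + (1 - y_i) \log(1 - \mu_i) \Big]
= \sum_i \Big[ y_i f_i(\beta) - \log\big(1 + e^{f_i(\beta)}\big) \Big] \\
\nabla_\beta \ell(\beta) &= \sum_i (y_i - \mu_i)\, \nabla_\beta f_i(\beta) \\
\nabla^2_\beta \ell(\beta) &= -\sum_i \mu_i (1 - \mu_i)\, \nabla_\beta f_i(\beta)\, \nabla_\beta f_i(\beta)^\top
+ \sum_i (y_i - \mu_i)\, \nabla^2_\beta f_i(\beta)
\end{aligned}
$$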
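Write $f_i(\beta)$ for the model’s logit, so that $\mu_i = \sigma(f_i(\beta))$. There are two interesting things to note about the Hessian of this expression:

1. The first term, $-\sum_i \mu_i(1-\mu_i)\, \nabla_\beta f_i \nabla_\beta f_i^\top$, is negative semidefinite by construction. Its negation, which is what enters the precision update, is therefore positive semidefinite.
2. The second term is zero in expectation, since $\mathbb{E}[y_i - \mu_i] = 0$.

…OK, three things:

3. The second term also vanishes if the logit is linear in the parameters, since then $\partial^2 f_i(\beta) / \partial \beta^2 = 0$. That will not be the case for the models under consideration here.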
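The Gauss-Newton approximation to the Hessian simply drops the second term. What is left is correct in expectation. Once negated, as it is in the precision update, it is positive semidefinite, so the updated precision $\Sigma_t^{-1}$ stays positive definite.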
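To see the difference numerically, here is a small sketch built on a toy model of my own rather than the post's network. It compares the exact Hessian of a Bernoulli log-likelihood, computed with jax.hessian, against the Gauss-Newton matrix $-J^\top \mathrm{diag}(\mu_i(1-\mu_i)) J$, where $J$ is the Jacobian of the logits. For logits that are linear in the parameters, the two should agree. For a non-linear model they generally differ, and only the Gauss-Newton version is guaranteed to be negative semidefinite.

```python
import jax
import jax.numpy as jnp


def bernoulli_loglik(logits_fn, params, X, y):
    # sum_i y_i log sigmoid(f_i) + (1 - y_i) log(1 - sigmoid(f_i)), computed stably
    f = logits_fn(params, X)
    return jnp.sum(y * jax.nn.log_sigmoid(f) + (1.0 - y) * jax.nn.log_sigmoid(-f))


def gauss_newton(logits_fn, params, X):
    # -J^T diag(p (1 - p)) J: the Hessian with the second-derivative term dropped
    J = jax.jacfwd(logits_fn)(params, X)  # (N, D) Jacobian of the logits
    p = jax.nn.sigmoid(logits_fn(params, X))
    return -J.T @ ((p * (1.0 - p))[:, None] * J)


def linear_logits(w, X):
    return X @ w


def tanh_logits(w, X):
    # Toy non-linear model: one tanh unit followed by a scalar output weight
    return w[2] * jnp.tanh(X @ w[:2])


k_x, k_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k_x, (50, 2))
y = jax.random.bernoulli(k_y, 0.5, (50,)).astype(jnp.float32)

w_lin = jnp.array([0.3, -0.8])
H_lin = jax.hessian(lambda w: bernoulli_loglik(linear_logits, w, X, y))(w_lin)
GN_lin = gauss_newton(linear_logits, w_lin, X)  # matches H_lin: logits are linear

w_nl = jnp.array([1.5, -2.0, 3.0])
H_nl = jax.hessian(lambda w: bernoulli_loglik(tanh_logits, w, X, y))(w_nl)
GN_nl = gauss_newton(tanh_logits, w_nl, X)  # differs from H_nl, but is NSD
```

The post's `gauss_newton_hessian` returns the negation of this matrix, $J^\top \mathrm{diag}(\mu_i(1-\mu_i)) J$. That is why it is added, rather than subtracted, in the precision update.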
Finally we can actually try out the R-VGA algorithm on some data.
As mentioned in the introduction, I’ve tried to reproduce the experiment from section 2.4.4 of Durán-Martín’s thesis [4]. The experiment models the binary classification ‘make moons’ dataset. Batches of randomly sampled data are streamed to us, and we are allowed to update our model parameters after each one.
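The data-generating code isn't shown in this post. A minimal stand-in might look like the sketch below. It assumes the usual ‘two moons’ construction (as in scikit-learn) but with uniformly sampled angles plus Gaussian noise. The function names `sample_moons` and `moons_stream` and the default noise level are my own choices. Each fresh key yields a new batch for the stream.

```python
import jax
import jax.numpy as jnp


def sample_moons(key, n, noise=0.1):
    # Label 0: upper half circle; label 1: lower half circle shifted right and down
    k_lab, k_ang, k_noise = jax.random.split(key, 3)
    labels = jax.random.bernoulli(k_lab, 0.5, (n,))
    angles = jax.random.uniform(k_ang, (n,), minval=0.0, maxval=jnp.pi)
    outer = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1)
    inner = jnp.stack([1.0 - jnp.cos(angles), 0.5 - jnp.sin(angles)], axis=-1)
    X = jnp.where(labels[:, None], inner, outer)
    X = X + noise * jax.random.normal(k_noise, (n, 2))
    return X, labels.astype(jnp.float32)


def moons_stream(key, n_steps, batch_size, noise=0.1):
    # Yield one fresh batch (X_t, y_t) per timestep
    for k in jax.random.split(key, n_steps):
        yield sample_moons(k, batch_size, noise)
```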
Our Bayesian neural network model will be a simple one-layer MLP. As in previous parts of this mini-series, I’m using jax and equinox to implement these models. Part 2 contains some details on implementing variational inference with jax/equinox, in case you missed it…
The model then, and some small utilities which will be familiar from last time:
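```python
import jax
import jax.numpy as jnp
import equinox as eqx
from typing import NamedTuple


class GaussianPosterior(NamedTuple):
    # Assumed minimal container (the post's own utility may differ):
    # mean vector and covariance matrix over the flattened parameters
    mu: jax.Array
    sigma: jax.Array


@eqx.filter_jit
def r_vga_update(
    posterior: GaussianPosterior,
    unravel,
    X: jax.Array,
    y: jax.Array,
    *,
    N_fixed_point: int = 10,
    N_monte_carlo: int = 10,
    eps: float = 1e-6,
    key,
):
    # (mu, sigma) is the previous posterior q_{t-1}, i.e. our prior for this step
    mu, sigma = posterior
    N_batch, _ = X.shape
    D = mu.size

    def loglik(vec_params):
        # Bernoulli log-likelihood of the batch; assumes the network
        # returns one scalar logit per input
        net = unravel(vec_params)
        logits = jax.vmap(net)(X)
        logp = -jax.nn.softplus(-logits)  # log sigmoid(logits)
        log1mp = -logits - jax.nn.softplus(-logits)  # log(1 - sigmoid(logits))
        return jnp.sum(jnp.where(y, logp, log1mp))

    def logits_fn(v):
        return jax.vmap(unravel(v))(X)

    # Make a 'prequential' prediction: given our current parameters
    # mu and sigma, predict for the new minibatch (X, y) before updating
    nll_loss = -loglik(mu)
    accuracy = jnp.mean((jax.nn.sigmoid(logits_fn(mu)) >= 0.5) == y)

    def stabilise(s):
        # Symmetrise and add a small trace-scaled jitter to the diagonal
        s = 0.5 * (s + s.T)
        return s + eps * jnp.trace(s) / D * jnp.eye(D)

    # Prior precision Sigma_{t-1}^{-1}, fixed across the fixed-point iterations
    lambda_prior = jnp.linalg.inv(stabilise(sigma))

    def gauss_newton_hessian(theta_):
        # J^T diag(p (1 - p)) J: a PSD approximation to minus the Hessian
        J = jax.jacrev(logits_fn)(theta_)  # (N_batch, D)
        probs = jax.nn.sigmoid(logits_fn(theta_))
        W = probs * (1.0 - probs)
        return J.T @ (W[:, None] * J)

    def fixed_point_itr(carry, key):
        # carry is the current estimate of q_t = N(mu_curr, sigma_curr).
        # Expectations are taken under it, but the updates always start
        # from the previous posterior (mu, sigma)
        mu_curr, sigma_curr = carry
        # Sample N_monte_carlo parameter vectors via the Cholesky factor
        # and estimate the expectations as sample means
        L = jnp.linalg.cholesky(stabilise(sigma_curr))
        z = jax.random.normal(key, (N_monte_carlo, D))
        theta = mu_curr + z @ L.T
        mean_gradient = jax.vmap(eqx.filter_grad(loglik))(theta).mean(axis=0)
        mean_hessian = jax.vmap(gauss_newton_hessian)(theta).mean(axis=0)

        # N.B. both updates scale the batch terms by 1/N_batch
        # Update covariance: Sigma_t^{-1} = Sigma_{t-1}^{-1} + E[H_GN]
        sigma_new = jnp.linalg.inv(lambda_prior + mean_hessian / N_batch)
        # Update mean: mu_t = mu_{t-1} + Sigma_{t-1} E[grad log p]
        mu_new = mu + sigma @ mean_gradient / N_batch
        return (mu_new, sigma_new), None

    # Iterate fixed-points using a scan rather than a loop,
    # giving each iteration its own PRNG key
    keys = jax.random.split(key, N_fixed_point)
    (mu_final, sigma_final), _ = jax.lax.scan(fixed_point_itr, (mu, sigma), keys)
    return GaussianPosterior(mu_final, sigma_final), nll_loss, accuracy
```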
In jax it is usual to express for-loop style logic using a scan. jax.lax.scan lets us ‘carry’ state through the ‘loop’. In this case the carried state is a tuple of our variational parameters as we move through the fixed-point iterations. The scanned function’s second argument is normally one element of an input sequence. Here I’m using that sequence to hand each iteration its own jax pseudorandom key, keeping all of our sampling independent.
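Here is a tiny standalone illustration of that pattern, unrelated to R-VGA itself. The carry is a running total, the scanned-over sequence is one PRNG key per iteration, and the stacked outputs record each step's draw.

```python
import jax
import jax.numpy as jnp


def step(total, key):
    # carry: running total; second argument: this iteration's own PRNG key
    draw = jax.random.normal(key)
    return total + draw, draw  # (new carry, per-iteration output to stack)


keys = jax.random.split(jax.random.PRNGKey(0), 5)
final_total, draws = jax.lax.scan(step, jnp.float32(0.0), keys)
```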
The expectation-of-gradients is happening quite quietly. At each fixed-point iteration I sample not one set of model parameters but N_monte_carlo of them. Then, when we need the gradient and (Gauss-Newton) Hessian of the model for the update, we vectorise the computation over the Monte Carlo dimension and take the mean.
Now all we need is a small training loop and we can give it a try:
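The training loop itself isn't reproduced here. As a self-contained stand-in, here is a hedged sketch of a prequential loop. It uses Bayesian logistic regression on a hypothetical two-Gaussian stream; the logits are linear, so the Gauss-Newton matrix equals the exact Hessian. It also uses an explicit variant of the update rather than the implicit scheme above. Expectations are taken under $q_{t-1}$, the precision is updated first, and the new covariance is used in the mean step. That choice is my assumption; for linear-Gaussian models this variant reduces to the Kalman update. The names `rvga_logistic_step` and `sample_batch` are hypothetical. At each step, we first score the incoming batch with the current posterior mean and only then update.

```python
import jax
import jax.numpy as jnp


def rvga_logistic_step(mu, sigma, X, y, key, n_mc=20):
    # Explicit variant: expectations under the current posterior N(mu, sigma)
    L = jnp.linalg.cholesky(sigma)
    theta = mu + jax.random.normal(key, (n_mc, mu.size)) @ L.T  # (n_mc, D)
    p = jax.nn.sigmoid(theta @ X.T)  # (n_mc, N)
    grad = ((y - p) @ X).mean(axis=0)  # E[X^T (y - p)]
    hess = jnp.einsum("sn,ni,nj->ij", p * (1 - p), X, X) / n_mc  # E[X^T W X]
    # Precision first, then the mean step with the new covariance
    sigma_new = jnp.linalg.inv(jnp.linalg.inv(sigma) + hess)
    sigma_new = 0.5 * (sigma_new + sigma_new.T)
    mu_new = mu + sigma_new @ grad
    return mu_new, sigma_new


def sample_batch(key, n):
    # Hypothetical stream: two Gaussian blobs plus a constant column for the bias
    k_lab, k_x = jax.random.split(key)
    y = jax.random.bernoulli(k_lab, 0.5, (n,)).astype(jnp.float32)
    centres = jnp.where(y[:, None] > 0.5, 1.5, -1.5)
    X = centres + jax.random.normal(k_x, (n, 2))
    return jnp.hstack([X, jnp.ones((n, 1))]), y


mu, sigma = jnp.zeros(3), jnp.eye(3)
accuracies = []
for step_key in jax.random.split(jax.random.PRNGKey(0), 50):
    k_data, k_mc = jax.random.split(step_key)
    X, y = sample_batch(k_data, 20)
    # Prequential evaluation: score the new batch *before* updating on it
    acc = jnp.mean((jax.nn.sigmoid(X @ mu) >= 0.5) == (y > 0.5))
    accuracies.append(float(acc))
    mu, sigma = rvga_logistic_step(mu, sigma, X, y, k_mc)
```

Plotting `accuracies` against the timestep gives a prequential accuracy curve: each point is measured on data the model had not yet been updated on.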
Running this and looking at the ‘prequential’ (i.e. predict one timestep ahead) negative log-likelihood (NLL) and accuracy, it looks like the model has been able to learn something!
To reinforce an important point: the x-axis above shows timesteps through our online learning process, not epochs. This is not the usual plot we’ve all seen and made after training a model. We’re not repeatedly showing the model the data and watching it improve at predicting the full dataset.

Instead, we’re watching the model update its beliefs after being shown a small batch of data at each timestep, and only that batch (not the preceding data). It gradually improves its understanding of the underlying dataset. Very cool.
We can also look at the decision boundary learned by the model through time:
In that last plot I chose to plot the mean prediction at each point on the input domain. But our model is still a Bayesian neural network, so we could, if we liked, ask it for a full predictive distribution at any given point on our input domain.
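Here is a sketch of what that could look like. It assumes a Gaussian posterior over the parameters and a function mapping a parameter vector to a logit; to keep the example self-contained, that function is a hypothetical linear logit with a bias. We push posterior samples through the model and summarise the resulting predictive probabilities by their mean and spread.

```python
import jax
import jax.numpy as jnp


def predictive(logit_fn, mu, sigma, x, key, n_samples=1000):
    # Monte Carlo predictive distribution of p(y = 1 | x) under N(mu, sigma)
    L = jnp.linalg.cholesky(sigma)
    thetas = mu + jax.random.normal(key, (n_samples, mu.size)) @ L.T
    probs = jax.vmap(lambda th: jax.nn.sigmoid(logit_fn(th, x)))(thetas)
    return probs.mean(), probs.std()


def linear_logit(theta, x):
    # Hypothetical model: logit = w . x + b with theta = (w1, w2, b)
    return theta[:2] @ x + theta[2]


mu = jnp.array([1.0, 1.0, 0.0])
x = jnp.array([2.0, 2.0])
key = jax.random.PRNGKey(0)
confident = predictive(linear_logit, mu, 0.01 * jnp.eye(3), x, key)
uncertain = predictive(linear_logit, mu, 4.0 * jnp.eye(3), x, key)
```

With a wider posterior the predictive mean is pulled towards 0.5 and the spread grows. Evaluating this over a grid would give an uncertainty map to sit alongside the decision boundary plot.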
In this post, we’ve built on the two previous posts to combine ideas from exact recursive Bayesian inference (Part 1) and variational inference for neural networks (Part 2) in order to implement the Recursive Variational Gaussian Approximation (R-VGA) algorithm.
In the next part, the last of this mini-series, we’ll extend the R-VGA model to describe non-stationary data. Somewhat surprisingly, this isn’t as hard as it sounds: the heavy lifting was done in this post, don’t worry :)